replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0

replay/utils/dataframe_bucketizer.py
CHANGED

@@ -9,39 +9,37 @@ if PYSPARK_AVAILABLE:
     from replay.utils.session_handler import State


-class DataframeBucketizer(
-    Transformer, DefaultParamsWritable, DefaultParamsReadable
-):  # pylint: disable=R0901
+class DataframeBucketizer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
     """
     Buckets the input dataframe, dumps it to spark warehouse directory,
     and returns a bucketed dataframe.
     """

-
+    bucketing_key = Param(
         Params._dummy(),
-        "
+        "bucketing_key",
         "bucketing key (also used as sort key)",
         typeConverter=TypeConverters.toString,
     )

-
+    partition_num = Param(
         Params._dummy(),
-        "
+        "partition_num",
         "number of buckets",
         typeConverter=TypeConverters.toInt,
     )

-
+    table_name = Param(
         Params._dummy(),
-        "
+        "table_name",
         "parquet file name (for storage in 'spark-warehouse') and spark table name",
         typeConverter=TypeConverters.toString,
     )

-
+    spark_warehouse_dir = Param(
         Params._dummy(),
-        "
-        "
+        "spark_warehouse_dir",
+        "spark_warehouse_dir",
         typeConverter=TypeConverters.toString,
     )

@@ -62,10 +60,10 @@ class DataframeBucketizer(
             i.e. value of 'spark.sql.warehouse.dir' property
         """
         super().__init__()
-        self.set(self.
-        self.set(self.
-        self.set(self.
-        self.set(self.
+        self.set(self.bucketing_key, bucketing_key)
+        self.set(self.partition_num, partition_num)
+        self.set(self.table_name, table_name)
+        self.set(self.spark_warehouse_dir, spark_warehouse_dir)

     def __enter__(self):
         return self
@@ -76,31 +74,27 @@ class DataframeBucketizer(
     def remove_parquet(self):
         """Removes parquets where bucketed dataset is stored"""
         spark = State().session
-        spark_warehouse_dir = self.getOrDefault(self.
-        table_name = self.getOrDefault(self.
-        fs = get_fs(spark)
-        fs_path = spark._jvm.org.apache.hadoop.fs.Path(
-            f"{spark_warehouse_dir}/{table_name}"
-        )
+        spark_warehouse_dir = self.getOrDefault(self.spark_warehouse_dir)
+        table_name = self.getOrDefault(self.table_name)
+        fs = get_fs(spark)
+        fs_path = spark._jvm.org.apache.hadoop.fs.Path(f"{spark_warehouse_dir}/{table_name}")
         is_exists = fs.exists(fs_path)
         if is_exists:
             fs.delete(fs_path, True)

     def set_table_name(self, table_name: str):
         """Sets table name"""
-        self.set(self.
+        self.set(self.table_name, table_name)

     def _transform(self, dataset: SparkDataFrame):
-        bucketing_key = self.getOrDefault(self.
-        partition_num = self.getOrDefault(self.
-        table_name = self.getOrDefault(self.
-        spark_warehouse_dir = self.getOrDefault(self.
+        bucketing_key = self.getOrDefault(self.bucketing_key)
+        partition_num = self.getOrDefault(self.partition_num)
+        table_name = self.getOrDefault(self.table_name)
+        spark_warehouse_dir = self.getOrDefault(self.spark_warehouse_dir)

         if not table_name:
-
-
-                "Please set it via method 'set_table_name'."
-            )
+            msg = "Parameter 'table_name' is not set! Please set it via method 'set_table_name'."
+            raise ValueError(msg)

         (
             dataset.repartition(partition_num, bucketing_key)

replay/utils/distributions.py
CHANGED

@@ -22,23 +22,11 @@ def item_distribution(
     :return: DataFrame with results
     """
     log = convert2spark(log)
-    res = (
-        log.groupBy("item_idx")
-        .agg(sf.countDistinct("user_idx").alias("user_count"))
-        .select("item_idx", "user_count")
-    )
+    res = log.groupBy("item_idx").agg(sf.countDistinct("user_idx").alias("user_count")).select("item_idx", "user_count")

     rec = convert2spark(recommendations)
     rec = get_top_k_recs(rec, k)
-    rec = (
-        rec.groupBy("item_idx")
-        .agg(sf.countDistinct("user_idx").alias("rec_count"))
-        .select("item_idx", "rec_count")
-    )
+    rec = rec.groupBy("item_idx").agg(sf.countDistinct("user_idx").alias("rec_count")).select("item_idx", "rec_count")

-    res = (
-        res.join(rec, on="item_idx", how="outer")
-        .fillna(0)
-        .orderBy(["user_count", "item_idx"])
-    )
+    res = res.join(rec, on="item_idx", how="outer").fillna(0).orderBy(["user_count", "item_idx"])
     return spark_to_pandas(res, allow_collect_to_master)
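
An illustrative call of item_distribution under the new formatting; the parameter names are taken from the function body shown above, and the toy frames stand in for a real interaction log and recommendation table.

    # Sketch only: parameter names inferred from the function body above.
    import pandas as pd
    from replay.utils.distributions import item_distribution

    log = pd.DataFrame({"user_idx": [1, 1, 2], "item_idx": [10, 11, 10], "relevance": [1.0, 1.0, 0.5]})
    recommendations = pd.DataFrame({"user_idx": [1, 2], "item_idx": [11, 10], "relevance": [0.9, 0.8]})

    dist = item_distribution(log, recommendations, k=1, allow_collect_to_master=True)
    # pandas DataFrame with item_idx, user_count (from the log) and rec_count (from the top-k recs)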

replay/utils/model_handler.py
CHANGED

@@ -1,17 +1,18 @@
-
+import functools
 import json
 import os
 import pickle
+import warnings
 from os.path import join
 from pathlib import Path
-from typing import Union
+from typing import Any, Callable, Optional, Union

 from replay.data.dataset_utils import DatasetLabelEncoder
 from replay.models import *
 from replay.models.base_rec import BaseRecommender
 from replay.splitters import *
-from .session_handler import State

+from .session_handler import State
 from .types import PYSPARK_AVAILABLE

 if PYSPARK_AVAILABLE:
@@ -26,9 +27,7 @@ if PYSPARK_AVAILABLE:
         :param spark: spark session
         :return:
         """
-        fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(
-            spark._jsc.hadoopConfiguration()
-        )
+        fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(spark._jsc.hadoopConfiguration())
         return fs

     def get_list_of_paths(spark: SparkSession, dir_path: str):
@@ -44,9 +43,7 @@ if PYSPARK_AVAILABLE:
         return [str(f.getPath()) for f in statuses]


-def save(
-    model: BaseRecommender, path: Union[str, Path], overwrite: bool = False
-):
+def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False):
     """
     Save fitted model to disk as a folder

@@ -63,9 +60,8 @@ def save(
     if not overwrite:
         is_exists = fs.exists(spark._jvm.org.apache.hadoop.fs.Path(path))
         if is_exists:
-
-
-            )
+            msg = f"Path '{path}' already exists. Mode is 'overwrite = False'."
+            raise FileExistsError(msg)

     fs.mkdirs(spark._jvm.org.apache.hadoop.fs.Path(path))
     model._save_model(join(path, "model"))
@@ -74,9 +70,7 @@ def save(
     init_args["_model_name"] = str(model)
     sc = spark.sparkContext
     df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
-    df.coalesce(1).write.mode("overwrite").option(
-        "ignoreNullFields", "false"
-    ).json(join(path, "init_args.json"))
+    df.coalesce(1).write.mode("overwrite").option("ignoreNullFields", "false").json(join(path, "init_args.json"))

     dataframes = model._dataframes
     df_path = join(path, "dataframes")
@@ -85,13 +79,9 @@ def save(
         df.write.mode("overwrite").parquet(join(df_path, name))

     if hasattr(model, "fit_queries"):
-        model.fit_queries.write.mode("overwrite").parquet(
-            join(df_path, "fit_queries")
-        )
+        model.fit_queries.write.mode("overwrite").parquet(join(df_path, "fit_queries"))
     if hasattr(model, "fit_items"):
-        model.fit_items.write.mode("overwrite").parquet(
-            join(df_path, "fit_items")
-        )
+        model.fit_items.write.mode("overwrite").parquet(join(df_path, "fit_items"))
     if hasattr(model, "study"):
         save_picklable_to_parquet(model.study, join(path, "study"))

@@ -104,18 +94,11 @@ def load(path: str, model_type=None) -> BaseRecommender:
     :return: Restored trained model
     """
     spark = State().session
-    args = (
-        spark.read.json(join(path, "init_args.json"))
-        .first()
-        .asDict(recursive=True)
-    )
+    args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
     name = args["_model_name"]
     del args["_model_name"]

-    if model_type is not None
-        model_class = model_type
-    else:
-        model_class = globals()[name]
+    model_class = model_type if model_type is not None else globals()[name]

     model = model_class(**args)

@@ -180,9 +163,7 @@ def save_splitter(splitter: Splitter, path: str, overwrite: bool = False):
     sc = spark.sparkContext
     df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
     if overwrite:
-        df.coalesce(1).write.mode("overwrite").json(
-            join(path, "init_args.json")
-        )
+        df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))
     else:
         df.coalesce(1).write.json(join(path, "init_args.json"))

@@ -200,3 +181,25 @@ def load_splitter(path: str) -> Splitter:
     del args["_splitter_name"]
     splitter = globals()[name]
     return splitter(**args)
+
+
+def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
+    """
+    Decorator that throws deprecation warnings.
+
+    :param message: message to deprecation warning without func name.
+    """
+    base_msg = "will be deprecated in future versions."
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            msg = f"{func.__qualname__} {message if message else base_msg}"
+            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
+            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+            warnings.simplefilter("default", DeprecationWarning)  # reset filter
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
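
The newly added deprecation_warning helper is applied as a decorator. A small sketch follows; the decorated function and the message text are hypothetical examples, only the decorator itself comes from the diff above.

    # Sketch only: `old_entry_point` and its message are made-up examples.
    from replay.utils.model_handler import deprecation_warning

    @deprecation_warning("is deprecated and will be removed in a future release.")
    def old_entry_point(x):
        return x

    old_entry_point(42)
    # DeprecationWarning: old_entry_point is deprecated and will be removed in a future release.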

replay/utils/session_handler.py
CHANGED

@@ -36,7 +36,6 @@ def get_spark_session(
         Default: ``None``.
     """
     if os.environ.get("SCRIPT_ENV", None) == "cluster":  # pragma: no cover
-        # pylint: disable=no-member
         return SparkSession.builder.getOrCreate()

     os.environ["PYSPARK_PYTHON"] = sys.executable
@@ -46,33 +45,32 @@
         path_to_replay_jar = os.environ.get("REPLAY_JAR_PATH")
     else:
         if pyspark_version.startswith("3.1"):  # pragma: no cover
-            path_to_replay_jar =
-
-
-        ):
+            path_to_replay_jar = (
+                "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
+            )
+        elif pyspark_version.startswith(("3.2", "3.3")):
            path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.2.0_als_metrics/replay_2.12-3.2.0_als_metrics.jar"
        elif pyspark_version.startswith("3.4"):  # pragma: no cover
            path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.4.0_als_metrics/replay_2.12-3.4.0_als_metrics.jar"
        else:  # pragma: no cover
-            path_to_replay_jar =
+            path_to_replay_jar = (
+                "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
+            )
            logging.warning(
-                "Replay ALS model support only spark 3.1-3.4 versions! "
-                "
+                "Replay ALS model support only spark 3.1-3.4 versions! Replay will use "
+                "'https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar' "
+                "in 'spark.jars' property."
            )

    if core_count is None:  # checking out env variable
        core_count = int(os.environ.get("REPLAY_SPARK_CORE_COUNT", "-1"))
    if spark_memory is None:
        env_var = os.environ.get("REPLAY_SPARK_MEMORY")
-        if env_var is not None
-            spark_memory = int(env_var)
-        else:  # pragma: no cover
-            spark_memory = floor(psutil.virtual_memory().total / 1024**3 * 0.7)
+        spark_memory = int(env_var) if env_var is not None else floor(psutil.virtual_memory().total / 1024**3 * 0.7)
    if shuffle_partitions is None:
        shuffle_partitions = os.cpu_count() * 3
    driver_memory = f"{spark_memory}g"
    user_home = os.environ["HOME"]
-    # pylint: disable=no-member
    spark = (
        SparkSession.builder.config("spark.driver.memory", driver_memory)
        .config(
@@ -111,7 +109,6 @@ def logger_with_settings() -> logging.Logger:
     return logger


-# pylint: disable=too-few-public-methods
 class Borg:
     """
     This class allows to share objects between instances.
@@ -123,7 +120,6 @@ class Borg:
         self.__dict__ = self._shared_state


-# pylint: disable=too-few-public-methods
 class State(Borg):
     """
     All modules look for Spark session via this class. You can put your own session here.
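
The environment variables read above can drive the session configuration without passing arguments explicitly. A brief sketch, assuming the names used inside the function body (core_count, spark_memory, shuffle_partitions) are keyword parameters of get_spark_session:

    # Sketch only: parameter names inferred from the function body above.
    import os
    from replay.utils.session_handler import get_spark_session

    os.environ["REPLAY_SPARK_MEMORY"] = "8"        # driver memory in GB, used when spark_memory is None
    os.environ["REPLAY_SPARK_CORE_COUNT"] = "4"    # used when core_count is None

    session = get_spark_session()                  # roughly equivalent to get_spark_session(spark_memory=8, core_count=4)
    # State() is the shared holder through which other replay modules pick up this session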

replay/utils/spark_utils.py
CHANGED

@@ -10,14 +10,17 @@ import pandas as pd
 from numpy.random import default_rng

 from .session_handler import State
-
 from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, SparkDataFrame

 if PYSPARK_AVAILABLE:
     import pyspark.sql.types as st
     from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT
-    from pyspark.sql import
-
+    from pyspark.sql import (
+        Column,
+        SparkSession,
+        Window,
+        functions as sf,
+    )
     from pyspark.sql.column import _to_java_column, _to_seq
     from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
 else:
@@ -48,7 +51,6 @@ def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False)
     return data.toPandas()


-# pylint: disable=invalid-name
 def convert2spark(data_frame: Optional[DataFrameLike]) -> Optional[SparkDataFrame]:
     """
     Converts Pandas DataFrame to Spark DataFrame
@@ -61,7 +63,7 @@ def convert2spark(data_frame: Optional[DataFrameLike]) -> Optional[SparkDataFram
     if isinstance(data_frame, SparkDataFrame):
         return data_frame
     spark = State().session
-    return spark.createDataFrame(data_frame)
+    return spark.createDataFrame(data_frame)


 def get_top_k(
@@ -76,7 +78,11 @@ def get_top_k(

     >>> from replay.utils.session_handler import State
     >>> spark = State().session
-    >>> log =
+    >>> log = (
+    ...     spark
+    ...     .createDataFrame([(1, 2, 1.), (1, 3, 1.), (1, 4, 0.5), (2, 1, 1.)])
+    ...     .toDF("user_id", "item_id", "relevance")
+    ... )
     >>> log.show()
     +-------+-------+---------+
     |user_id|item_id|relevance|
@@ -108,9 +114,7 @@ def get_top_k(
     return (
         dataframe.withColumn(
             "temp_rank",
-            sf.row_number().over(
-                Window.partitionBy(partition_by_col).orderBy(*order_by_col)
-            ),
+            sf.row_number().over(Window.partitionBy(partition_by_col).orderBy(*order_by_col)),
         )
         .filter(sf.col("temp_rank") <= k)
         .drop("temp_rank")
@@ -141,6 +145,7 @@ def get_top_k_recs(


 if PYSPARK_AVAILABLE:
+
     @sf.udf(returnType=st.DoubleType())
     def vector_dot(one: DenseVector, two: DenseVector) -> float:  # pragma: no cover
         """
@@ -179,10 +184,8 @@ if PYSPARK_AVAILABLE:
         """
         return float(one.dot(two))

-    @sf.udf(returnType=VectorUDT())
-    def vector_mult(
-        one: Union[DenseVector, NumType], two: DenseVector
-    ) -> DenseVector:  # pragma: no cover
+    @sf.udf(returnType=VectorUDT())
+    def vector_mult(one: Union[DenseVector, NumType], two: DenseVector) -> DenseVector:  # pragma: no cover
         """
         elementwise vector multiplication

@@ -271,9 +274,7 @@ def multiply_scala_udf(scalar, vector):
     return Column(_f.apply(_to_seq(sc, [scalar, vector], _to_java_column)))


-def get_log_info(
-    log: SparkDataFrame, user_col="user_idx", item_col="item_idx"
-) -> str:
+def get_log_info(log: SparkDataFrame, user_col="user_idx", item_col="item_idx") -> str:
     """
     Basic log statistics

@@ -310,9 +311,7 @@ def get_log_info(
     )


-def get_stats(
-    log: SparkDataFrame, group_by: str = "user_id", target_column: str = "relevance"
-) -> SparkDataFrame:
+def get_stats(log: SparkDataFrame, group_by: str = "user_id", target_column: str = "relevance") -> SparkDataFrame:
     """
     Calculate log statistics: min, max, mean, median ratings, number of ratings.
     >>> from replay.utils.session_handler import get_spark_session, State
@@ -351,14 +350,9 @@ def get_stats(
         "count": sf.count,
     }
     agg_functions_list = [
-        func(target_column).alias(str(name + "_" + target_column))
-        for name, func in agg_functions.items()
+        func(target_column).alias(str(name + "_" + target_column)) for name, func in agg_functions.items()
     ]
-    agg_functions_list.append(
-        sf.expr(f"percentile_approx({target_column}, 0.5)").alias(
-            "median_" + target_column
-        )
-    )
+    agg_functions_list.append(sf.expr(f"percentile_approx({target_column}, 0.5)").alias("median_" + target_column))

     return log.groupBy(group_by).agg(*agg_functions_list)

@@ -369,13 +363,9 @@ def check_numeric(feature_table: SparkDataFrame) -> None:
     :param feature_table: spark DataFrame
     """
     for column in feature_table.columns:
-        if not isinstance(
-            feature_table.schema[column].dataType,
-
-            raise ValueError(
-                f"""Column {column} has type {feature_table.schema[
-                column].dataType}, that is not numeric."""
-            )
+        if not isinstance(feature_table.schema[column].dataType, st.NumericType):
+            msg = f"Column {column} has type {feature_table.schema[column].dataType}, that is not numeric."
+            raise ValueError(msg)


 def horizontal_explode(
@@ -420,10 +410,7 @@ def horizontal_explode(
     num_columns = len(data_frame.select(column_to_explode).head()[0])
     return data_frame.select(
         *other_columns,
-        *[
-            sf.element_at(column_to_explode, i + 1).alias(f"{prefix}_{i}")
-            for i in range(num_columns)
-        ],
+        *[sf.element_at(column_to_explode, i + 1).alias(f"{prefix}_{i}") for i in range(num_columns)],
     )


@@ -442,7 +429,6 @@ def join_or_return(first, second, on, how):
     return first.join(second, on=on, how=how)


-# pylint: disable=too-many-arguments
 def fallback(
     base: SparkDataFrame,
     fill: SparkDataFrame,
@@ -471,15 +457,11 @@ def fallback(
     diff = max_in_fill - min_in_base
     fill = fill.withColumnRenamed(rating_column, "relevance_fallback")
     if diff >= 0:
-        fill = fill.withColumn(
-
-
-
-        fill, on=[query_column, item_column], how="full_outer"
+        fill = fill.withColumn("relevance_fallback", sf.col("relevance_fallback") - diff - margin)
+    recs = base.join(fill, on=[query_column, item_column], how="full_outer")
+    recs = recs.withColumn(rating_column, sf.coalesce(rating_column, "relevance_fallback")).select(
+        query_column, item_column, rating_column
     )
-    recs = recs.withColumn(
-        rating_column, sf.coalesce(rating_column, "relevance_fallback")
-    ).select(query_column, item_column, rating_column)
     recs = get_top_k_recs(recs, k, query_column=query_column, rating_column=rating_column)
     return recs

@@ -537,9 +519,7 @@ def join_with_col_renaming(
         right = right.withColumnRenamed(name, f"{name}_{suffix}")
         on_condition &= sf.col(name) == sf.col(f"{name}_{suffix}")

-    return (left.join(right, on=on_condition, how=how)).drop(
-        *[f"{name}_{suffix}" for name in on_col_name]
-    )
+    return (left.join(right, on=on_condition, how=how)).drop(*[f"{name}_{suffix}" for name in on_col_name])


 def process_timestamp_column(
@@ -562,7 +542,8 @@ def process_timestamp_column(
     :return: dataframe with updated column ``column_name``
     """
     if column_name not in dataframe.columns:
-
+        msg = f"Column {column_name} not found"
+        raise ValueError(msg)

     # no conversion needed
     if isinstance(dataframe.schema[column_name].dataType, st.TimestampType):
@@ -570,9 +551,7 @@ def process_timestamp_column(

     # unix timestamp
     if isinstance(dataframe.schema[column_name].dataType, st.NumericType):
-        return dataframe.withColumn(
-            column_name, sf.to_timestamp(sf.from_unixtime(sf.col(column_name)))
-        )
+        return dataframe.withColumn(column_name, sf.to_timestamp(sf.from_unixtime(sf.col(column_name))))

     # datetime in string format
     dataframe = dataframe.withColumn(
@@ -583,6 +562,7 @@


 if PYSPARK_AVAILABLE:
+
     @sf.udf(returnType=VectorUDT())
     def list_to_vector_udf(array: st.ArrayType) -> DenseVector:  # pragma: no cover
         """
@@ -603,9 +583,7 @@ if PYSPARK_AVAILABLE:
         return float(first.squared_distance(second))

     @sf.udf(returnType=st.FloatType())
-    def vector_euclidean_distance_similarity(
-        first: DenseVector, second: DenseVector
-    ) -> float:  # pragma: no cover
+    def vector_euclidean_distance_similarity(first: DenseVector, second: DenseVector) -> float:  # pragma: no cover
         """
         :param first: first vector
         :param second: second vector
@@ -642,7 +620,7 @@ def drop_temp_view(temp_view_name: str) -> None:
     spark.catalog.dropTempView(temp_view_name)


-def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: int = None):
+def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: Optional[int] = None):
     """
     Sample k items for each user with probability proportional to the relevance score.

@@ -660,17 +638,13 @@ def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: int = None):
     """
     pairs = pairs.withColumn(
         "probability",
-        sf.col("relevance")
-        / sf.sum("relevance").over(Window.partitionBy("user_idx")),
+        sf.col("relevance") / sf.sum("relevance").over(Window.partitionBy("user_idx")),
     )

     def grouped_map(pandas_df: pd.DataFrame) -> pd.DataFrame:  # pragma: no cover
         user_idx = pandas_df["user_idx"][0]

-        if seed is not None
-            local_rng = default_rng(seed + user_idx)
-        else:
-            local_rng = default_rng()
+        local_rng = default_rng(seed + user_idx) if seed is not None else default_rng()


         items_positions = local_rng.choice(
@@ -686,6 +660,7 @@ def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: int = None):
                 "relevance": pandas_df["relevance"].values[items_positions],
             }
         )
+
     rec_schema = StructType(
         [
             StructField("user_idx", IntegerType()),
@@ -716,19 +691,12 @@ def filter_cold(
     if df is None:
         return 0, df

-    num_cold = (
-        df.select(col_name)
-        .distinct()
-        .join(warm_df, on=col_name, how="anti")
-        .count()
-    )
+    num_cold = df.select(col_name).distinct().join(warm_df, on=col_name, how="anti").count()

     if num_cold == 0:
         return 0, df

-    return num_cold, df.join(
-        warm_df.select(col_name), on=col_name, how="inner"
-    )
+    return num_cold, df.join(warm_df.select(col_name), on=col_name, how="inner")


 def get_unique_entities(
@@ -745,17 +713,14 @@
     if isinstance(df, SparkDataFrame):
         unique = df.select(column).distinct()
     elif isinstance(df, collections.abc.Iterable):
-        unique = spark.createDataFrame(
-            data=pd.DataFrame(pd.unique(list(df)), columns=[column])
-        )
+        unique = spark.createDataFrame(data=pd.DataFrame(pd.unique(list(df)), columns=[column]))
     else:
-
+        msg = f"Wrong type {type(df)}"
+        raise ValueError(msg)
     return unique


-def return_recs(
-    recs: SparkDataFrame, recs_file_path: Optional[str] = None
-) -> Optional[SparkDataFrame]:
+def return_recs(recs: SparkDataFrame, recs_file_path: Optional[str] = None) -> Optional[SparkDataFrame]:
     """
     Save dataframe `recs` to `recs_file_path` if presents otherwise cache
     and materialize the dataframe.
@@ -785,7 +750,7 @@ def save_picklable_to_parquet(obj: Any, path: str) -> None:
     sc = State().session.sparkContext
     # We can use `RDD.saveAsPickleFile`, but it has no "overwrite" parameter
     pickled_instance = pickle.dumps(obj)
-    Record = collections.namedtuple("Record", ["data"])
+    Record = collections.namedtuple("Record", ["data"])  # noqa: PYI024
     rdd = sc.parallelize([Record(pickled_instance)])
     instance_df = rdd.map(lambda rec: Record(bytearray(rec.data))).toDF()
     instance_df.write.mode("overwrite").parquet(path)
@@ -812,9 +777,10 @@ def assert_omp_single_thread():
     PyTorch uses multithreading for cpu math operations via OpenMP library. Sometimes this
     leads to failures when OpenMP multithreading is mixed with multiprocessing.
     """
-    omp_num_threads = os.environ.get(
-    if omp_num_threads !=
-
-    'Environment variable "OMP_NUM_THREADS" is set to "
-
+    omp_num_threads = os.environ.get("OMP_NUM_THREADS", None)
+    if omp_num_threads != "1":
+        msg = (
+            f'Environment variable "OMP_NUM_THREADS" is set to "{omp_num_threads}". '
+            f"Set it to 1 if the working process freezes."
         )
+        logging.getLogger("replay").warning(msg)
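
As a usage note for the reworked check above: assert_omp_single_thread emits its message through the "replay" logger, so it is only visible when that logger (or the root logger) is configured. A minimal sketch, assuming nothing beyond the code shown in this hunk:

    # Sketch only: demonstrates the warning path of assert_omp_single_thread.
    import logging
    import os

    from replay.utils.spark_utils import assert_omp_single_thread

    logging.basicConfig(level=logging.WARNING)
    os.environ["OMP_NUM_THREADS"] = "4"
    assert_omp_single_thread()
    # logs: Environment variable "OMP_NUM_THREADS" is set to "4". Set it to 1 if the working process freezes.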