replay-rec 0.17.0rc0__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +246 -20
- replay/data/nn/schema.py +42 -0
- replay/data/nn/sequence_tokenizer.py +17 -47
- replay/data/nn/sequential_dataset.py +76 -2
- replay/preprocessing/filters.py +169 -4
- replay/splitters/base_splitter.py +1 -1
- replay/utils/common.py +107 -5
- replay/utils/spark_utils.py +13 -6
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/METADATA +3 -11
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/RECORD +13 -66
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -61
- replay/experimental/metrics/base_metric.py +0 -601
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -10
- replay/experimental/models/admm_slim.py +0 -205
- replay/experimental/models/base_neighbour_rec.py +0 -204
- replay/experimental/models/base_rec.py +0 -1271
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -452
- replay/experimental/models/ddpg.py +0 -921
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -265
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -302
- replay/experimental/models/mult_vae.py +0 -331
- replay/experimental/models/neuromf.py +0 -405
- replay/experimental/models/scala_als.py +0 -296
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -55
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -838
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -181
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.17.0rc0.dist-info/NOTICE +0 -41
- {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/LICENSE +0 -0
replay/preprocessing/filters.py CHANGED

@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
 from typing import Callable, Optional, Tuple, Union
 
+import numpy as np
+import pandas as pd
 import polars as pl
 
 from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
@@ -357,7 +359,7 @@ class NumInteractionsFilter(_BaseFilter):
     ...     "2020-02-01", "2020-01-01 00:04:15",
     ...     "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ... )
-    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
     >>> log_sp = convert2spark(log_pd)
     >>> log_sp.show()
     +-------+-------+------+-------------------+
@@ -499,7 +501,7 @@ class EntityDaysFilter(_BaseFilter):
     ...     "2020-02-01", "2020-01-01 00:04:15",
     ...     "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ... )
-    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
     >>> log_sp = convert2spark(log_pd)
     >>> log_sp.orderBy('user_id', 'item_id').show()
     +-------+-------+------+-------------------+
@@ -638,7 +640,7 @@ class GlobalDaysFilter(_BaseFilter):
     ...     "2020-02-01", "2020-01-01 00:04:15",
     ...     "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ... )
-    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
     >>> log_sp = convert2spark(log_pd)
     >>> log_sp.show()
     +-------+-------+------+-------------------+
@@ -740,7 +742,7 @@ class TimePeriodFilter(_BaseFilter):
     ...     "2020-02-01", "2020-01-01 00:04:15",
     ...     "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
     ... )
-    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+    >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
     >>> log_sp = convert2spark(log_pd)
     >>> log_sp.show()
     +-------+-------+------+-------------------+
@@ -823,3 +825,166 @@ class TimePeriodFilter(_BaseFilter):
         return interactions.filter(
             pl.col(self.timestamp_column).is_between(self.start_date, self.end_date, closed="left")
         )
+
+
+class QuantileItemsFilter(_BaseFilter):
+    """
+    Filter is aimed on undersampling the interactions dataset.
+
+    Filter algorithm performs undersampling by removing `items_proportion` of interactions
+    for each items counts that exceeds the `alpha_quantile` value in distribution. Filter firstly
+    removes popular items (items that have most interactions). Filter also keeps the original
+    relation of items popularity among each other by removing interactions only in range of
+    current item count and quantile count (specified by `alpha_quantile`).
+
+    >>> import pandas as pd
+    >>> from replay.utils.spark_utils import convert2spark
+    >>> log_pd = pd.DataFrame({
+    ...     "user_id": [0, 0, 1, 2, 2, 2, 2],
+    ...     "item_id": [0, 2, 1, 1, 2, 2, 2]
+    ... })
+    >>> log_spark = convert2spark(log_pd)
+    >>> log_spark.show()
+    +-------+-------+
+    |user_id|item_id|
+    +-------+-------+
+    |      0|      0|
+    |      0|      2|
+    |      1|      1|
+    |      2|      1|
+    |      2|      2|
+    |      2|      2|
+    |      2|      2|
+    +-------+-------+
+    <BLANKLINE>
+
+    >>> QuantileItemsFilter(query_column="user_id").transform(log_spark).show()
+    +-------+-------+
+    |user_id|item_id|
+    +-------+-------+
+    |      0|      0|
+    |      1|      1|
+    |      2|      1|
+    |      2|      2|
+    |      2|      2|
+    |      0|      2|
+    +-------+-------+
+    <BLANKLINE>
+    """
+
+    def __init__(
+        self,
+        alpha_quantile: float = 0.99,
+        items_proportion: float = 0.5,
+        query_column: str = "query_id",
+        item_column: str = "item_id",
+    ) -> None:
+        """
+        :param alpha_quantile: Quantile value of items counts distribution to keep unchanged.
+            Every items count that exceeds this value will be undersampled.
+            Default: ``0.99``.
+        :param items_proportion: proportion of items counts to remove for items that
+            exceeds `alpha_quantile` value in range of current item count and quantile count
+            to make sure we keep original relation between items unchanged.
+            Default: ``0.5``.
+        :param query_column: query column name.
+            Default: ``query_id``.
+        :param item_column: item column name.
+            Default: ``item_id``.
+        """
+        if not 0 < alpha_quantile < 1:
+            msg = "`alpha_quantile` value must be in (0, 1)"
+            raise ValueError(msg)
+        if not 0 < items_proportion < 1:
+            msg = "`items_proportion` value must be in (0, 1)"
+            raise ValueError(msg)
+
+        self.alpha_quantile = alpha_quantile
+        self.items_proportion = items_proportion
+        self.query_column = query_column
+        self.item_column = item_column
+
+    def _filter_pandas(self, df: pd.DataFrame):
+        items_distribution = df.groupby(self.item_column).size().reset_index().rename(columns={0: "counts"})
+        users_distribution = df.groupby(self.query_column).size().reset_index().rename(columns={0: "counts"})
+        count_threshold = items_distribution.loc[:, "counts"].quantile(self.alpha_quantile, interpolation="midpoint")
+        df_with_counts = df.merge(items_distribution, how="left", on=self.item_column).merge(
+            users_distribution, how="left", on=self.query_column, suffixes=["_items", "_users"]
+        )
+        long_tail = df_with_counts.loc[df_with_counts["counts_items"] <= count_threshold]
+        short_tail = df_with_counts.loc[df_with_counts["counts_items"] > count_threshold]
+        short_tail["num_items_to_delete"] = self.items_proportion * (
+            short_tail["counts_items"] - long_tail["counts_items"].max()
+        )
+        short_tail["num_items_to_delete"] = short_tail["num_items_to_delete"].astype("int")
+        short_tail = short_tail.sort_values("counts_users", ascending=False)
+
+        def get_mask(x):
+            mask = np.ones_like(x)
+            threshold = x.iloc[0]
+            mask[:threshold] = 0
+            return mask
+
+        mask = short_tail.groupby(self.item_column)["num_items_to_delete"].transform(get_mask).astype(bool)
+        return pd.concat([long_tail[df.columns], short_tail.loc[mask][df.columns]])
+
+    def _filter_polars(self, df: pl.DataFrame):
+        items_distribution = df.group_by(self.item_column).len()
+        users_distribution = df.group_by(self.query_column).len()
+        count_threshold = items_distribution.select("len").quantile(self.alpha_quantile, "midpoint")["len"][0]
+        df_with_counts = (
+            df.join(items_distribution, how="left", on=self.item_column).join(
+                users_distribution, how="left", on=self.query_column
+            )
+        ).rename({"len": "counts_items", "len_right": "counts_users"})
+        long_tail = df_with_counts.filter(pl.col("counts_items") <= count_threshold)
+        short_tail = df_with_counts.filter(pl.col("counts_items") > count_threshold)
+        max_long_tail_count = long_tail["counts_items"].max()
+        items_to_delete = (
+            short_tail.select(
+                self.query_column,
+                self.item_column,
+                self.items_proportion * (pl.col("counts_items") - max_long_tail_count),
+            )
+            .with_columns(pl.col("literal").cast(pl.Int64).alias("num_items_to_delete"))
+            .select(self.item_column, "num_items_to_delete")
+            .unique(maintain_order=True)
+        )
+        short_tail = short_tail.join(items_to_delete, how="left", on=self.item_column).sort(
+            "counts_users", descending=True
+        )
+        short_tail = short_tail.with_columns(index=pl.int_range(short_tail.shape[0]))
+        grouped = short_tail.group_by(self.item_column, maintain_order=True).agg(
+            pl.col("index"), pl.col("num_items_to_delete")
+        )
+        grouped = grouped.with_columns(
+            pl.col("num_items_to_delete").list.get(0),
+            (pl.col("index").list.len() - pl.col("num_items_to_delete").list.get(0)).alias("tail"),
+        )
+        grouped = grouped.with_columns(pl.col("index").list.tail(pl.col("tail")))
+        grouped = grouped.explode("index").select("index")
+        short_tail = grouped.join(short_tail, how="left", on="index")
+        return pl.concat([long_tail.select(df.columns), short_tail.select(df.columns)])
+
+    def _filter_spark(self, df: SparkDataFrame):
+        items_distribution = df.groupBy(self.item_column).agg(sf.count(self.query_column).alias("counts_items"))
+        users_distribution = df.groupBy(self.query_column).agg(sf.count(self.item_column).alias("counts_users"))
+        count_threshold = items_distribution.toPandas().loc[:, "counts_items"].quantile(self.alpha_quantile, "midpoint")
+        df_with_counts = df.join(items_distribution, on=self.item_column).join(users_distribution, on=self.query_column)
+        long_tail = df_with_counts.filter(sf.col("counts_items") <= count_threshold)
+        short_tail = df_with_counts.filter(sf.col("counts_items") > count_threshold)
+        max_long_tail_count = long_tail.agg({"counts_items": "max"}).collect()[0][0]
+        items_to_delete = (
+            short_tail.withColumn(
+                "num_items_to_delete",
+                (self.items_proportion * (sf.col("counts_items") - max_long_tail_count)).cast("int"),
+            )
+            .select(self.item_column, "num_items_to_delete")
+            .distinct()
+        )
+        short_tail = short_tail.join(items_to_delete, on=self.item_column, how="left")
+        short_tail = short_tail.withColumn(
+            "index", sf.row_number().over(Window.partitionBy(self.item_column).orderBy(sf.col("counts_users").desc()))
+        )
+        short_tail = short_tail.filter(sf.col("index") > sf.col("num_items_to_delete"))
+        return long_tail.select(df.columns).union(short_tail.select(df.columns))
replay/utils/common.py CHANGED

@@ -1,7 +1,12 @@
+import functools
+import inspect
 import json
 from pathlib import Path
-from typing import Union
+from typing import Any, Callable, Union
 
+from polars import from_pandas as pl_from_pandas
+
+from replay.data.dataset import Dataset
 from replay.splitters import (
     ColdUserRandomSplitter,
     KFolds,
@@ -12,7 +17,16 @@ from replay.splitters import (
     TimeSplitter,
     TwoStageSplitter,
 )
-from replay.utils import TORCH_AVAILABLE
+from replay.utils import (
+    TORCH_AVAILABLE,
+    PandasDataFrame,
+    PolarsDataFrame,
+    SparkDataFrame,
+)
+from replay.utils.spark_utils import (
+    convert2spark as pandas_to_spark,
+    spark_to_pandas,
+)
 
 SavableObject = Union[
     ColdUserRandomSplitter,
@@ -23,10 +37,11 @@ SavableObject = Union[
     RatioSplitter,
     TimeSplitter,
     TwoStageSplitter,
+    Dataset,
 ]
 
 if TORCH_AVAILABLE:
-    from replay.data.nn import SequenceTokenizer
+    from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer
 
     SavableObject = Union[
         ColdUserRandomSplitter,
@@ -38,6 +53,8 @@ if TORCH_AVAILABLE:
         TimeSplitter,
         TwoStageSplitter,
         SequenceTokenizer,
+        PandasSequentialDataset,
+        PolarsSequentialDataset,
     ]
 
 
@@ -50,7 +67,7 @@ def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
     obj.save(path)
 
 
-def load_from_replay(path: Union[str, Path]) -> SavableObject:
+def load_from_replay(path: Union[str, Path], **kwargs) -> SavableObject:
     """
     General function to load RePlay models, splitters and tokenizer.
 
@@ -60,6 +77,91 @@ def load_from_replay(path: Union[str, Path]) -> SavableObject:
     with open(path / "init_args.json", "r") as file:
         class_name = json.loads(file.read())["_class_name"]
     obj_type = globals()[class_name]
-    obj = obj_type.load(path)
+    obj = obj_type.load(path, **kwargs)
 
     return obj
+
+
+def _check_if_dataframe(var: Any):
+    if not isinstance(var, (SparkDataFrame, PolarsDataFrame, PandasDataFrame)):
+        msg = f"Object of type {type(var)} is not a dataframe of known type (can be pandas|spark|polars)"
+        raise ValueError(msg)
+
+
+def check_if_dataframe(*args_to_check: str) -> Callable[..., Any]:
+    def decorator_func(func: Callable[..., Any]) -> Callable[..., Any]:
+        @functools.wraps(func)
+        def wrap_func(*args: Any, **kwargs: Any) -> Any:
+            extended_kwargs = {}
+            extended_kwargs.update(kwargs)
+            extended_kwargs.update(dict(zip(inspect.signature(func).parameters.keys(), args)))
+            # add default param values to dict with arguments
+            extended_kwargs.update(
+                {
+                    x.name: x.default
+                    for x in inspect.signature(func).parameters.values()
+                    if x.name not in extended_kwargs and x.default is not x.empty
+                }
+            )
+            vals_to_check = [extended_kwargs[_arg] for _arg in args_to_check]
+            for val in vals_to_check:
+                _check_if_dataframe(val)
+            return func(*args, **kwargs)
+
+        return wrap_func
+
+    return decorator_func
+
+
+@check_if_dataframe("data")
+def convert2pandas(
+    data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
+) -> PandasDataFrame:
+    """
+    Convert the spark|polars DataFrame to a pandas.DataFrame.
+    Returns unchanged dataframe if the input is already of type pandas.DataFrame.
+
+    :param data: The dataframe to convert. Can be polars|spark|pandas DataFrame.
+    :param allow_collect_to_master: If set to False (default) raises a warning
+        about collecting parallelized data to the master node.
+    """
+    if isinstance(data, PandasDataFrame):
+        return data
+    if isinstance(data, PolarsDataFrame):
+        return data.to_pandas()
+    if isinstance(data, SparkDataFrame):
+        return spark_to_pandas(data, allow_collect_to_master, from_constructor=False)
+
+
+@check_if_dataframe("data")
+def convert2polars(
+    data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
+) -> PolarsDataFrame:
+    """
+    Convert the spark|pandas DataFrame to a polars.DataFrame.
+    Returns unchanged dataframe if the input is already of type polars.DataFrame.
+
+    :param data: The dataframe to convert. Can be spark|pandas|polars DataFrame.
+    :param allow_collect_to_master: If set to False (default) raises a warning
+        about collecting parallelized data to the master node.
+    """
+    if isinstance(data, PandasDataFrame):
+        return pl_from_pandas(data)
+    if isinstance(data, PolarsDataFrame):
+        return data
+    if isinstance(data, SparkDataFrame):
+        return pl_from_pandas(spark_to_pandas(data, allow_collect_to_master, from_constructor=False))
+
+
+@check_if_dataframe("data")
+def convert2spark(data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame]) -> SparkDataFrame:
+    """
+    Convert the pandas|polars DataFrame to a pysaprk.sql.DataFrame.
+    Returns unchanged dataframe if the input is already of type pysaprk.sql.DataFrame.
+
+    :param data: The dataframe to convert. Can be pandas|polars|spark Datarame.
+    """
+    if isinstance(data, (PandasDataFrame, SparkDataFrame)):
+        return pandas_to_spark(data)
+    if isinstance(data, PolarsDataFrame):
+        return pandas_to_spark(data.to_pandas())
replay/utils/spark_utils.py CHANGED

@@ -33,7 +33,9 @@ class SparkCollectToMasterWarning(Warning):  # pragma: no cover
     """
 
 
-def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False) -> pd.DataFrame:  # pragma: no cover
+def spark_to_pandas(
+    data: SparkDataFrame, allow_collect_to_master: bool = False, from_constructor: bool = True
+) -> pd.DataFrame:  # pragma: no cover
     """
     Convert Spark DataFrame to Pandas DataFrame.
 
@@ -42,10 +44,15 @@ def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False)
 
     :returns: Converted Pandas DataFrame.
     """
+    warn_msg = "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
+    if from_constructor:
+        _msg = "To remove this warning set allow_collect_to_master=True in the recommender constructor."
+    else:
+        _msg = "To remove this warning set allow_collect_to_master=True."
+    warn_msg += _msg
     if not allow_collect_to_master:
         warnings.warn(
-            "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
-            "To remove this warning set allow_collect_to_master=True in the recommender constructor.",
+            warn_msg,
             SparkCollectToMasterWarning,
         )
     return data.toPandas()
@@ -169,7 +176,7 @@ if PYSPARK_AVAILABLE:
         <BLANKLINE>
         >>> output_data = input_data.select(vector_dot("one", "two").alias("dot"))
         >>> output_data.schema
-        StructType(
+        StructType([StructField('dot', DoubleType(), True)])
         >>> output_data.show()
         +----+
         | dot|
@@ -207,7 +214,7 @@ if PYSPARK_AVAILABLE:
         <BLANKLINE>
        >>> output_data = input_data.select(vector_mult("one", "two").alias("mult"))
         >>> output_data.schema
-        StructType(
+        StructType([StructField('mult', VectorUDT(), True)])
         >>> output_data.show()
         +---------+
         |     mult|
@@ -244,7 +251,7 @@ if PYSPARK_AVAILABLE:
         <BLANKLINE>
         >>> output_data = input_data.select(array_mult("one", "two").alias("mult"))
         >>> output_data.schema
-        StructType(
+        StructType([StructField('mult', ArrayType(DoubleType(), True), True)])
         >>> output_data.show()
         +----------+
         |      mult|
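The from_constructor flag only switches the suffix of the collect-to-master warning, so callers outside a recommender constructor (such as the convert2pandas helper above) get a message that no longer points at a constructor. A small check, assuming a live SparkSession and some pyspark.sql.DataFrame named spark_df (both hypothetical here):

    import warnings
    from replay.utils.spark_utils import SparkCollectToMasterWarning, spark_to_pandas

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pdf = spark_to_pandas(spark_df, allow_collect_to_master=False, from_constructor=False)

    assert caught[0].category is SparkCollectToMasterWarning
    # With from_constructor=False the message ends with
    # "set allow_collect_to_master=True." and drops the
    # "in the recommender constructor" wording, per the branch added above.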
{replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: replay-rec
-Version: 0.17.0rc0
+Version: 0.17.1
 Summary: RecSys Library
 Home-page: https://sb-ai-lab.github.io/RePlay/
 License: Apache-2.0
@@ -20,25 +20,17 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Provides-Extra: all
 Provides-Extra: spark
 Provides-Extra: torch
-Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
-Requires-Dist: gym (>=0.26.0,<0.27.0)
 Requires-Dist: hnswlib (==0.7.0)
-Requires-Dist: implicit (>=0.7.0,<0.8.0)
-Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
-Requires-Dist: lightfm (==1.17)
 Requires-Dist: lightning (>=2.0.2,<3.0.0) ; extra == "torch" or extra == "all"
-Requires-Dist: llvmlite (>=0.32.1)
 Requires-Dist: nmslib (==2.1.1)
-Requires-Dist: numba (>=0.50)
 Requires-Dist: numpy (>=1.20.0)
 Requires-Dist: optuna (>=3.2.0,<3.3.0)
-Requires-Dist: pandas (>=1.3.5
+Requires-Dist: pandas (>=1.3.5,<=2.2.2)
 Requires-Dist: polars (>=0.20.7,<0.21.0)
 Requires-Dist: psutil (>=5.9.5,<5.10.0)
 Requires-Dist: pyarrow (>=12.0.1)
-Requires-Dist: pyspark (>=3.0,<3.
+Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
 Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
-Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
 Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
 Requires-Dist: scipy (>=1.8.1,<1.9.0)
 Requires-Dist: torch (>=1.8,<2.0) ; extra == "torch" or extra == "all"
{replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/RECORD CHANGED

@@ -1,68 +1,16 @@
-replay/__init__.py,sha256=
+replay/__init__.py,sha256=wUk_ODIXbOTEQKc4cIBpsptZ--yblkgTGRfXStYmQKI,46
 replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
-replay/data/dataset.py,sha256=
+replay/data/dataset.py,sha256=cSStvCqIc6WAJNtbmsxncSpcQZ1KfULMsrmf_V0UdPw,29490
 replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
 replay/data/dataset_utils/dataset_label_encoder.py,sha256=TEx2zLw5rJdIz1SRBEznyVv5x_Cs7o6QQbzMk-M1LU0,9598
 replay/data/nn/__init__.py,sha256=WxLsi4rgOuuvGYHN49xBPxP2Srhqf3NYgfBDVH-ZvBo,1122
-replay/data/nn/schema.py,sha256=
-replay/data/nn/sequence_tokenizer.py,sha256=
-replay/data/nn/sequential_dataset.py,sha256=
+replay/data/nn/schema.py,sha256=pO4N7RgmgrqfD1-2d95OTeihKHTZ-5y2BG7CX_wBFi4,16198
+replay/data/nn/sequence_tokenizer.py,sha256=Ambrp3CMOp3JP68PiwmVh0m-_zNXiWzxxVreHkEwOyY,32592
+replay/data/nn/sequential_dataset.py,sha256=jCWxC0Pm1eQ5p8Y6_Bmg4fSEvPaecLrqz1iaWzaICdI,11014
 replay/data/nn/torch_sequential_dataset.py,sha256=BqrK_PtkhpsaY1zRIWGk4EgwPL31a7IWCc0hLDuwDQc,10984
 replay/data/nn/utils.py,sha256=YKE9gkIHZDDiwv4THqOWL4PzsdOujnPuM97v79Mwq0E,2769
 replay/data/schema.py,sha256=F_cv6sYb6l23yuX5xWnbqoJ9oSeUT2NpIM19u8Lf2jA,15606
 replay/data/spark_schema.py,sha256=4o0Kn_fjwz2-9dBY3q46F9PL0F3E7jdVpIlX7SG3OZI,1111
-replay/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-replay/experimental/metrics/__init__.py,sha256=W6S9YTGCezLORyTKCqL4Y_PniC1k3Bu5XWIM3WVHg2Q,2860
-replay/experimental/metrics/base_metric.py,sha256=aYmKZ_336dRrlslBzYsgsOzmed54BNjNXsRcpzB5gyM,22648
-replay/experimental/metrics/coverage.py,sha256=3kVBAUhIEOuD8aJ6DShH2xh_1F61dcLZb001VCkmeJk,3154
-replay/experimental/metrics/experiment.py,sha256=Bd_XB9zbngcAwf5JLZKVPsFWQoz9pEGlPEUbkiR_MDc,7343
-replay/experimental/metrics/hitrate.py,sha256=TfWJrUyZXabdMr4tn8zqUPGDcYy2yphVCzXmLSHCxY0,675
-replay/experimental/metrics/map.py,sha256=S4dKiMpYR0_pu0bqioGMT0kIC1s2aojFP4rddBqMPtM,921
-replay/experimental/metrics/mrr.py,sha256=q6I1Cndlwr716mMuYtTMu0lN8Rrp9khxhb49OM2IpV8,530
-replay/experimental/metrics/ncis_precision.py,sha256=yrErOhBZvZdNpQPx_AXyktDJatqdWRIHNMyei0QDJtQ,1088
-replay/experimental/metrics/ndcg.py,sha256=q3KTsyZCrfvcpEjEnR_kWVB9ZaTFRxnoNRAr2WD0TrU,1538
-replay/experimental/metrics/precision.py,sha256=U9pD9yRGeT8uH32BTyQ-W5qsAnbFWu-pqy4XfkcXfCM,664
-replay/experimental/metrics/recall.py,sha256=5xRPGxfbVoDFEI5E6dVlZpT4RvnDlWzaktyoqh3a8mc,774
-replay/experimental/metrics/rocauc.py,sha256=yq4vW2_bXO8HCjREBZVrHMKeZ054LYvjJmLJTXWPfQA,1675
-replay/experimental/metrics/surprisal.py,sha256=CK4_zed2bSMDwC7ZBCS8d8RwGEqt8bh3w3fTpjKiK6Y,3052
-replay/experimental/metrics/unexpectedness.py,sha256=JQQXEYHtQM8nqp7X2He4E9ZYwbpdENaK8oQG7sUQT3s,2621
-replay/experimental/models/__init__.py,sha256=R284PXgSxt-JWWwlSTLggchash0hrLfy4b2w-ySaQf4,588
-replay/experimental/models/admm_slim.py,sha256=Oz-x0aQAnGFN9z7PB7MiKfduBasc4KQrBT0JwtYdwLY,6581
-replay/experimental/models/base_neighbour_rec.py,sha256=pRcffr0cdRNZRVpzWb2Qv-UIsLkhbs7K1GRAmrSqPSM,7506
-replay/experimental/models/base_rec.py,sha256=rj2r7r_mmJdzKAkg5CHG1eqJhOpUHAETPe0NwfibFjU,49606
-replay/experimental/models/base_torch_rec.py,sha256=oDkCxVFQjIHSWKlCns6mU3ECWbQW3mQZWvBHBxJQdwc,8111
-replay/experimental/models/cql.py,sha256=9ONDMblfxUgol5Pb2UInfSHVRbB2Ma15zAZC6valhtk,19628
-replay/experimental/models/ddpg.py,sha256=sZrGgwj_kKeUnwwT9qooc4Cxz2oVGkNfUwUe1N7mreI,31982
-replay/experimental/models/dt4rec/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-replay/experimental/models/dt4rec/dt4rec.py,sha256=ZIHYonDubStN7Gb703csy86R7Q3_1fZc4zJf98HYFe4,5895
-replay/experimental/models/dt4rec/gpt1.py,sha256=T3buFtYyF6Fh6sW6f9dUZFcFEnQdljItbRa22CiKb0w,14044
-replay/experimental/models/dt4rec/trainer.py,sha256=YeaJ8mnoYZqnPwm1P9qOYb8GzgFC5At-JeSDcvG2V2o,3859
-replay/experimental/models/dt4rec/utils.py,sha256=jbCx2Xc85VtjQx-caYhJFfVuj1Wf866OAiSoZlR4q48,8201
-replay/experimental/models/extensions/spark_custom_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-replay/experimental/models/extensions/spark_custom_models/als_extension.py,sha256=dKSVCMXWRB7IUnpEK_QNhSEuUSVcG793E8MT_AGXneY,25890
-replay/experimental/models/implicit_wrap.py,sha256=8F-f-CaStmlNHwphu-yu8o4Aft08NKDD_SqqH0zp1Uo,4655
-replay/experimental/models/lightfm_wrap.py,sha256=a2ctIEoZf7I0C_awiQI1lE4RGJ7ISs60znysgHRXZCw,11337
-replay/experimental/models/mult_vae.py,sha256=FdJ-GL6Jj2l5-38edKp_jsNfwFNGPxMHXKn8cG2tGJs,11607
-replay/experimental/models/neuromf.py,sha256=QRu--zIyOSQIp8R5Ksgiw7o0s5yOhQpuAX9YshKJs4w,14391
-replay/experimental/models/scala_als.py,sha256=PVf0YA3ii4iRwGqpYg6nStgaauyrm9QTzLtK_4f1En0,10985
-replay/experimental/nn/data/__init__.py,sha256=5EAF-FNd7xhkUpTq_5MyVcPXBD81mJCwYrcbhdGOWjE,48
-replay/experimental/nn/data/schema_builder.py,sha256=5PphL9kK-tVm30nWdTjHUzqVOnTwKiU_MlxGdL5HJ8Y,1736
-replay/experimental/preprocessing/__init__.py,sha256=uMyeyQ_GKqjLhVGwhrEk3NLhhzS0DKi5xGo3VF4WkiA,130
-replay/experimental/preprocessing/data_preparator.py,sha256=fQ8Blo_uzA-2eC-_ViVeU26Tqj5lxLTCBoDJfEmiqUo,35968
-replay/experimental/preprocessing/padder.py,sha256=o7S_Zk-ne_jria3QhWCKkYa6bEqhCdtvCA-R0MjOvU4,9569
-replay/experimental/preprocessing/sequence_generator.py,sha256=E1_0uZJLv8V_n6YzRlgUWtcrHIdjNwPeBN-BMbz0e-A,9053
-replay/experimental/scenarios/__init__.py,sha256=gWFLCkLyOmOppvbRMK7C3UMlMpcbIgiGVolSH6LPgWA,91
-replay/experimental/scenarios/obp_wrapper/__init__.py,sha256=rsRyfsTnVNp20LkTEugwoBrV9XWbIhR8tOqec_Au6dY,450
-replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py,sha256=vmLANYB5i1UR3uY7e-T0IBEYwPxOYHtKqhkmUvMUYhU,2548
-replay/experimental/scenarios/obp_wrapper/replay_offline.py,sha256=A6TPBFHj_UUL0N6DHSF0-hACsH5cw2o1GMYvpPS6964,8756
-replay/experimental/scenarios/obp_wrapper/utils.py,sha256=-ioWTb73NmHWxVxw4BdSolctqeeGIyjKtydwc45nrrk,3271
-replay/experimental/scenarios/two_stages/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-replay/experimental/scenarios/two_stages/reranker.py,sha256=tJtWhbHRNV4sJZ9RZzqIfylTplKh9QVwTIBhEGGnXq8,4244
-replay/experimental/scenarios/two_stages/two_stages_scenario.py,sha256=ZgflnQ6xuxDFphdKX6Q0jtXidHS7c2YvDaccoaL78Qo,29846
-replay/experimental/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-replay/experimental/utils/logger.py,sha256=UwLowaeOG17sDEe32LiZel8MnjSTzeW7J3uLG1iwLuA,639
-replay/experimental/utils/model_handler.py,sha256=0ksSm5bJ1bL32VV5HI-KPe0a1EAzzOhMtmSYaM_zRrE,6271
-replay/experimental/utils/session_handler.py,sha256=076TLpTOcnh13BznNTtJW6Zhrqvm9Ee1mlpP5YMD4No,1313
 replay/metrics/__init__.py,sha256=KDkxVnKa4ks9K9GmlrdTx1pkIl-MAmm78ZASsp2ZndE,2812
 replay/metrics/base_metric.py,sha256=uleW5vLrdA3iRx72tFyW0cxe6ne_ugQ1XaY_ZTcnAOo,15960
 replay/metrics/categorical_diversity.py,sha256=OYsF-Ng-WrF9CC-sKgQKngrA779NO8MtgRvvAyC8MXM,10781
@@ -148,14 +96,14 @@ replay/optimization/__init__.py,sha256=az6U10rF7X6rPRUUPwLyiM1WFNJ_6kl0imA5xLVWF
 replay/optimization/optuna_objective.py,sha256=Z-8X0_FT3BicVWj0UhxoLrvZAck3Dhn7jHDGo0i0hxA,7653
 replay/preprocessing/__init__.py,sha256=TtBysFqYeDy4kZAEnWEaNSwPvbffYdfMkEs71YG51fM,411
 replay/preprocessing/converter.py,sha256=DczqsVLrwFi6EFhK2HR8rGiIxGCwXeY7QNgWorjA41g,4390
-replay/preprocessing/filters.py,sha256=
+replay/preprocessing/filters.py,sha256=wsXWQoZ-2aAecunLkaTxeLWi5ow4e3FAGcElx0iNx0w,41669
 replay/preprocessing/history_based_fp.py,sha256=tfgKJPKm53LSNqM6VmMXYsVrRDc-rP1Tbzn8s3mbziQ,18751
 replay/preprocessing/label_encoder.py,sha256=MLBavPD-dB644as0E9ZJSE9-8QxGCB_IHek1w3xtqDI,27040
 replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
 replay/scenarios/__init__.py,sha256=kw2wRkPPinw0IBA20D83XQ3xeSudk3KuYAAA1Wdr8xY,93
 replay/scenarios/fallback.py,sha256=EeBmIR-5igzKR2m55bQRFyhxTkpJez6ZkCW449n8hWs,7130
 replay/splitters/__init__.py,sha256=DnqVMelrzLwR8fGQgcWN_8FipGs8T4XGSPOMW-L_x2g,454
-replay/splitters/base_splitter.py,sha256=
+replay/splitters/base_splitter.py,sha256=hj9_GYDWllzv3XnxN6WHu1JKRRVjXo77vZEOLbF9v-s,7761
 replay/splitters/cold_user_random_splitter.py,sha256=gVwBVdn_0IOaLGT_UzJoS9AMaPhelZy-FpC5JQS1PhA,4136
 replay/splitters/k_folds.py,sha256=WH02_DP18A2ae893ysonmfLPB56_i1ETllTAwaCYekg,6218
 replay/splitters/last_n_splitter.py,sha256=r9kdq2JPi508C9ywjwc68an-iq27KsigMfHWLz0YohE,15346
@@ -165,16 +113,15 @@ replay/splitters/ratio_splitter.py,sha256=8zvuCn16Icc4ntQPKXJ5ArAWuJzCZ9NHZtgWct
 replay/splitters/time_splitter.py,sha256=iXhuafjBx7dWyJSy-TEVy1IUQBwMpA1gAiF4-GtRe2g,9031
 replay/splitters/two_stage_splitter.py,sha256=PWozxjjgjrVzdz6Sm9dcDTeH0bOA24reFzkk_N_TgbQ,17734
 replay/utils/__init__.py,sha256=vDJgOWq81fbBs-QO4ZDpdqR4KDyO1kMOOxBRi-5Gp7E,253
-replay/utils/common.py,sha256=
+replay/utils/common.py,sha256=s4Pro3QCkPeVBsj-s0vrbhd_pkJD-_-2M_sIguxGzQQ,5411
 replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
 replay/utils/distributions.py,sha256=kGGq2KzQZ-yhTuw_vtOsKFXVpXUOQ2l4aIFBcaDufZ8,1202
 replay/utils/model_handler.py,sha256=V-mHDh8_UexjVSsMBBRA9yrjS_5MPHwYOwv_UrI-Zfs,6466
 replay/utils/session_handler.py,sha256=ijTvDSNAe1D9R1e-dhtd-r80tFNiIBsFdWZLgw-gLEo,5153
-replay/utils/spark_utils.py,sha256=
+replay/utils/spark_utils.py,sha256=k5lUFM2C9QZKQON3dqhgfswyUF4tsgJOn0U2wCKimqM,26901
 replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
 replay/utils/types.py,sha256=5sw0A7NG4ZgQKdWORnBy0wBZ5F98sP_Ju8SKQ6zbDS4,651
-replay_rec-0.17.
-replay_rec-0.17.
-replay_rec-0.17.
-replay_rec-0.17.
-replay_rec-0.17.0rc0.dist-info/RECORD,,
+replay_rec-0.17.1.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
+replay_rec-0.17.1.dist-info/METADATA,sha256=IDkSzO_PcQgyU4Xqnpi0WTHkqyVS0t3vNvisONZaBLg,10589
+replay_rec-0.17.1.dist-info/WHEEL,sha256=Zb28QaM1gQi8f4VCBhsUklF61CTlNYfs9YAZn-TOGFk,88
+replay_rec-0.17.1.dist-info/RECORD,,
replay/experimental/__init__.py DELETED
File without changes

replay/experimental/metrics/__init__.py DELETED

@@ -1,61 +0,0 @@
-"""
-Most metrics require dataframe with recommendations
-and dataframe with ground truth values —
-which objects each user interacted with.
-
-- recommendations (Union[pandas.DataFrame, spark.DataFrame]):
-    predictions of a recommender system,
-    DataFrame with columns ``[user_id, item_id, relevance]``
-- ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
-    test data, DataFrame with columns
-    ``[user_id, item_id, timestamp, relevance]``
-
-Metric is calculated for all users, presented in ``ground_truth``
-for accurate metric calculation in case when the recommender system generated
-recommendation not for all users. It is assumed, that all users,
-we want to calculate metric for, have positive interactions.
-
-But if we have users, who observed the recommendations, but have not responded,
-those users will be ignored and metric will be overestimated.
-For such case we propose additional optional parameter ``ground_truth_users``,
-the dataframe with all users, which should be considered during the metric calculation.
-
-- ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
-    full list of users to calculate metric for, DataFrame with ``user_id`` column
-
-Every metric is calculated using top ``K`` items for each user.
-It is also possible to calculate metrics
-using multiple values for ``K`` simultaneously.
-In this case the result will be a dictionary and not a number.
-
-Make sure your recommendations do not contain user-item duplicates
-as duplicates could lead to the wrong calculation results.
-
-- k (Union[Iterable[int], int]):
-    a single number or a list, specifying the
-    truncation length for recommendation list for each user
-
-By default, metrics are averaged by users,
-but you can alternatively use method ``metric.median``.
-Also, you can get the lower bound
-of ``conf_interval`` for a given ``alpha``.
-
-Diversity metrics require extra parameters on initialization stage,
-but do not use ``ground_truth`` parameter.
-
-For each metric, a formula for its calculation is given, because this is
-important for the correct comparison of algorithms, as mentioned in our
-`article <https://arxiv.org/abs/2206.12858>`_.
-"""
-from replay.experimental.metrics.base_metric import Metric, NCISMetric
-from replay.experimental.metrics.coverage import Coverage
-from replay.experimental.metrics.hitrate import HitRate
-from replay.experimental.metrics.map import MAP
-from replay.experimental.metrics.mrr import MRR
-from replay.experimental.metrics.ncis_precision import NCISPrecision
-from replay.experimental.metrics.ndcg import NDCG
-from replay.experimental.metrics.precision import Precision
-from replay.experimental.metrics.recall import Recall
-from replay.experimental.metrics.rocauc import RocAuc
-from replay.experimental.metrics.surprisal import Surprisal
-from replay.experimental.metrics.unexpectedness import Unexpectedness