replay-rec 0.17.0rc0__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +246 -20
  3. replay/data/nn/schema.py +42 -0
  4. replay/data/nn/sequence_tokenizer.py +17 -47
  5. replay/data/nn/sequential_dataset.py +76 -2
  6. replay/preprocessing/filters.py +169 -4
  7. replay/splitters/base_splitter.py +1 -1
  8. replay/utils/common.py +107 -5
  9. replay/utils/spark_utils.py +13 -6
  10. {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/METADATA +3 -11
  11. {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/RECORD +13 -66
  12. {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/WHEEL +1 -1
  13. replay/experimental/__init__.py +0 -0
  14. replay/experimental/metrics/__init__.py +0 -61
  15. replay/experimental/metrics/base_metric.py +0 -601
  16. replay/experimental/metrics/coverage.py +0 -97
  17. replay/experimental/metrics/experiment.py +0 -175
  18. replay/experimental/metrics/hitrate.py +0 -26
  19. replay/experimental/metrics/map.py +0 -30
  20. replay/experimental/metrics/mrr.py +0 -18
  21. replay/experimental/metrics/ncis_precision.py +0 -31
  22. replay/experimental/metrics/ndcg.py +0 -49
  23. replay/experimental/metrics/precision.py +0 -22
  24. replay/experimental/metrics/recall.py +0 -25
  25. replay/experimental/metrics/rocauc.py +0 -49
  26. replay/experimental/metrics/surprisal.py +0 -90
  27. replay/experimental/metrics/unexpectedness.py +0 -76
  28. replay/experimental/models/__init__.py +0 -10
  29. replay/experimental/models/admm_slim.py +0 -205
  30. replay/experimental/models/base_neighbour_rec.py +0 -204
  31. replay/experimental/models/base_rec.py +0 -1271
  32. replay/experimental/models/base_torch_rec.py +0 -234
  33. replay/experimental/models/cql.py +0 -452
  34. replay/experimental/models/ddpg.py +0 -921
  35. replay/experimental/models/dt4rec/__init__.py +0 -0
  36. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  37. replay/experimental/models/dt4rec/gpt1.py +0 -401
  38. replay/experimental/models/dt4rec/trainer.py +0 -127
  39. replay/experimental/models/dt4rec/utils.py +0 -265
  40. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  41. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  42. replay/experimental/models/implicit_wrap.py +0 -131
  43. replay/experimental/models/lightfm_wrap.py +0 -302
  44. replay/experimental/models/mult_vae.py +0 -331
  45. replay/experimental/models/neuromf.py +0 -405
  46. replay/experimental/models/scala_als.py +0 -296
  47. replay/experimental/nn/data/__init__.py +0 -1
  48. replay/experimental/nn/data/schema_builder.py +0 -55
  49. replay/experimental/preprocessing/__init__.py +0 -3
  50. replay/experimental/preprocessing/data_preparator.py +0 -838
  51. replay/experimental/preprocessing/padder.py +0 -229
  52. replay/experimental/preprocessing/sequence_generator.py +0 -208
  53. replay/experimental/scenarios/__init__.py +0 -1
  54. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  55. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  56. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
  57. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  58. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  59. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  60. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  61. replay/experimental/utils/__init__.py +0 -0
  62. replay/experimental/utils/logger.py +0 -24
  63. replay/experimental/utils/model_handler.py +0 -181
  64. replay/experimental/utils/session_handler.py +0 -44
  65. replay_rec-0.17.0rc0.dist-info/NOTICE +0 -41
  66. {replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/LICENSE +0 -0
replay/preprocessing/filters.py CHANGED
@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
  from datetime import datetime, timedelta
  from typing import Callable, Optional, Tuple, Union

+ import numpy as np
+ import pandas as pd
  import polars as pl

  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
@@ -357,7 +359,7 @@ class NumInteractionsFilter(_BaseFilter):
  ... "2020-02-01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
  >>> log_sp = convert2spark(log_pd)
  >>> log_sp.show()
  +-------+-------+------+-------------------+
@@ -499,7 +501,7 @@ class EntityDaysFilter(_BaseFilter):
  ... "2020-02-01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
  >>> log_sp = convert2spark(log_pd)
  >>> log_sp.orderBy('user_id', 'item_id').show()
  +-------+-------+------+-------------------+
@@ -638,7 +640,7 @@ class GlobalDaysFilter(_BaseFilter):
  ... "2020-02-01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
  >>> log_sp = convert2spark(log_pd)
  >>> log_sp.show()
  +-------+-------+------+-------------------+
@@ -740,7 +742,7 @@ class TimePeriodFilter(_BaseFilter):
  ... "2020-02-01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
- >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"])
+ >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
  >>> log_sp = convert2spark(log_pd)
  >>> log_sp.show()
  +-------+-------+------+-------------------+
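All four doctest edits above make the same change: the example timestamp column mixes date-only strings ("2020-02-01") with full datetimes ("2020-01-01 00:04:15"), which pandas 2.x no longer parses without an explicit format hint, so the doctests pin the parser to ISO 8601. A minimal sketch of the difference, assuming pandas >= 2.0 (the pandas pin is relaxed to <=2.2.2 in the METADATA hunk below):

    import pandas as pd

    ts = pd.Series(["2020-02-01", "2020-01-01 00:04:15"])
    # Without a format, pandas 2.x infers "%Y-%m-%d" from the first element
    # and then raises on the second, mixed-precision string.
    parsed = pd.to_datetime(ts, format="ISO8601")  # accepts any ISO 8601 variant
    print(parsed.dtype)  # datetime64[ns]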
@@ -823,3 +825,166 @@ class TimePeriodFilter(_BaseFilter):
          return interactions.filter(
              pl.col(self.timestamp_column).is_between(self.start_date, self.end_date, closed="left")
          )
+
+
+ class QuantileItemsFilter(_BaseFilter):
+     """
+     Filter is aimed on undersampling the interactions dataset.
+
+     Filter algorithm performs undersampling by removing `items_proportion` of interactions
+     for each items counts that exceeds the `alpha_quantile` value in distribution. Filter firstly
+     removes popular items (items that have most interactions). Filter also keeps the original
+     relation of items popularity among each other by removing interactions only in range of
+     current item count and quantile count (specified by `alpha_quantile`).
+
+     >>> import pandas as pd
+     >>> from replay.utils.spark_utils import convert2spark
+     >>> log_pd = pd.DataFrame({
+     ... "user_id": [0, 0, 1, 2, 2, 2, 2],
+     ... "item_id": [0, 2, 1, 1, 2, 2, 2]
+     ... })
+     >>> log_spark = convert2spark(log_pd)
+     >>> log_spark.show()
+     +-------+-------+
+     |user_id|item_id|
+     +-------+-------+
+     |      0|      0|
+     |      0|      2|
+     |      1|      1|
+     |      2|      1|
+     |      2|      2|
+     |      2|      2|
+     |      2|      2|
+     +-------+-------+
+     <BLANKLINE>
+
+     >>> QuantileItemsFilter(query_column="user_id").transform(log_spark).show()
+     +-------+-------+
+     |user_id|item_id|
+     +-------+-------+
+     |      0|      0|
+     |      1|      1|
+     |      2|      1|
+     |      2|      2|
+     |      2|      2|
+     |      0|      2|
+     +-------+-------+
+     <BLANKLINE>
+     """
+
+     def __init__(
+         self,
+         alpha_quantile: float = 0.99,
+         items_proportion: float = 0.5,
+         query_column: str = "query_id",
+         item_column: str = "item_id",
+     ) -> None:
+         """
+         :param alpha_quantile: Quantile value of items counts distribution to keep unchanged.
+             Every items count that exceeds this value will be undersampled.
+             Default: ``0.99``.
+         :param items_proportion: proportion of items counts to remove for items that
+             exceeds `alpha_quantile` value in range of current item count and quantile count
+             to make sure we keep original relation between items unchanged.
+             Default: ``0.5``.
+         :param query_column: query column name.
+             Default: ``query_id``.
+         :param item_column: item column name.
+             Default: ``item_id``.
+         """
+         if not 0 < alpha_quantile < 1:
+             msg = "`alpha_quantile` value must be in (0, 1)"
+             raise ValueError(msg)
+         if not 0 < items_proportion < 1:
+             msg = "`items_proportion` value must be in (0, 1)"
+             raise ValueError(msg)
+
+         self.alpha_quantile = alpha_quantile
+         self.items_proportion = items_proportion
+         self.query_column = query_column
+         self.item_column = item_column
+
+     def _filter_pandas(self, df: pd.DataFrame):
+         items_distribution = df.groupby(self.item_column).size().reset_index().rename(columns={0: "counts"})
+         users_distribution = df.groupby(self.query_column).size().reset_index().rename(columns={0: "counts"})
+         count_threshold = items_distribution.loc[:, "counts"].quantile(self.alpha_quantile, interpolation="midpoint")
+         df_with_counts = df.merge(items_distribution, how="left", on=self.item_column).merge(
+             users_distribution, how="left", on=self.query_column, suffixes=["_items", "_users"]
+         )
+         long_tail = df_with_counts.loc[df_with_counts["counts_items"] <= count_threshold]
+         short_tail = df_with_counts.loc[df_with_counts["counts_items"] > count_threshold]
+         short_tail["num_items_to_delete"] = self.items_proportion * (
+             short_tail["counts_items"] - long_tail["counts_items"].max()
+         )
+         short_tail["num_items_to_delete"] = short_tail["num_items_to_delete"].astype("int")
+         short_tail = short_tail.sort_values("counts_users", ascending=False)
+
+         def get_mask(x):
+             mask = np.ones_like(x)
+             threshold = x.iloc[0]
+             mask[:threshold] = 0
+             return mask
+
+         mask = short_tail.groupby(self.item_column)["num_items_to_delete"].transform(get_mask).astype(bool)
+         return pd.concat([long_tail[df.columns], short_tail.loc[mask][df.columns]])
+
+     def _filter_polars(self, df: pl.DataFrame):
+         items_distribution = df.group_by(self.item_column).len()
+         users_distribution = df.group_by(self.query_column).len()
+         count_threshold = items_distribution.select("len").quantile(self.alpha_quantile, "midpoint")["len"][0]
+         df_with_counts = (
+             df.join(items_distribution, how="left", on=self.item_column).join(
+                 users_distribution, how="left", on=self.query_column
+             )
+         ).rename({"len": "counts_items", "len_right": "counts_users"})
+         long_tail = df_with_counts.filter(pl.col("counts_items") <= count_threshold)
+         short_tail = df_with_counts.filter(pl.col("counts_items") > count_threshold)
+         max_long_tail_count = long_tail["counts_items"].max()
+         items_to_delete = (
+             short_tail.select(
+                 self.query_column,
+                 self.item_column,
+                 self.items_proportion * (pl.col("counts_items") - max_long_tail_count),
+             )
+             .with_columns(pl.col("literal").cast(pl.Int64).alias("num_items_to_delete"))
+             .select(self.item_column, "num_items_to_delete")
+             .unique(maintain_order=True)
+         )
+         short_tail = short_tail.join(items_to_delete, how="left", on=self.item_column).sort(
+             "counts_users", descending=True
+         )
+         short_tail = short_tail.with_columns(index=pl.int_range(short_tail.shape[0]))
+         grouped = short_tail.group_by(self.item_column, maintain_order=True).agg(
+             pl.col("index"), pl.col("num_items_to_delete")
+         )
+         grouped = grouped.with_columns(
+             pl.col("num_items_to_delete").list.get(0),
+             (pl.col("index").list.len() - pl.col("num_items_to_delete").list.get(0)).alias("tail"),
+         )
+         grouped = grouped.with_columns(pl.col("index").list.tail(pl.col("tail")))
+         grouped = grouped.explode("index").select("index")
+         short_tail = grouped.join(short_tail, how="left", on="index")
+         return pl.concat([long_tail.select(df.columns), short_tail.select(df.columns)])
+
+     def _filter_spark(self, df: SparkDataFrame):
+         items_distribution = df.groupBy(self.item_column).agg(sf.count(self.query_column).alias("counts_items"))
+         users_distribution = df.groupBy(self.query_column).agg(sf.count(self.item_column).alias("counts_users"))
+         count_threshold = items_distribution.toPandas().loc[:, "counts_items"].quantile(self.alpha_quantile, "midpoint")
+         df_with_counts = df.join(items_distribution, on=self.item_column).join(users_distribution, on=self.query_column)
+         long_tail = df_with_counts.filter(sf.col("counts_items") <= count_threshold)
+         short_tail = df_with_counts.filter(sf.col("counts_items") > count_threshold)
+         max_long_tail_count = long_tail.agg({"counts_items": "max"}).collect()[0][0]
+         items_to_delete = (
+             short_tail.withColumn(
+                 "num_items_to_delete",
+                 (self.items_proportion * (sf.col("counts_items") - max_long_tail_count)).cast("int"),
+             )
+             .select(self.item_column, "num_items_to_delete")
+             .distinct()
+         )
+         short_tail = short_tail.join(items_to_delete, on=self.item_column, how="left")
+         short_tail = short_tail.withColumn(
+             "index", sf.row_number().over(Window.partitionBy(self.item_column).orderBy(sf.col("counts_users").desc()))
+         )
+         short_tail = short_tail.filter(sf.col("index") > sf.col("num_items_to_delete"))
+         return long_tail.select(df.columns).union(short_tail.select(df.columns))
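The new filter implements pandas, polars and Spark code paths (_filter_pandas, _filter_polars, _filter_spark). A small usage sketch on a plain pandas frame, mirroring the Spark doctest above; it assumes _BaseFilter.transform dispatches on the dataframe type the same way it does for the other filters in this module, and passes user_id explicitly because the default query column is query_id:

    import pandas as pd

    from replay.preprocessing.filters import QuantileItemsFilter

    log = pd.DataFrame({
        "user_id": [0, 0, 1, 2, 2, 2, 2],
        "item_id": [0, 2, 1, 1, 2, 2, 2],
    })
    # Items whose interaction count exceeds the 0.99 quantile are undersampled:
    # roughly half of the excess over the long-tail maximum is dropped, removing
    # rows of the most active users first.
    filtered = QuantileItemsFilter(query_column="user_id").transform(log)
    print(filtered)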
replay/splitters/base_splitter.py CHANGED
@@ -85,7 +85,7 @@ class Splitter(ABC):
              json.dump(splitter_dict, file)

      @classmethod
-     def load(cls, path: str) -> "Splitter":
+     def load(cls, path: str, **kwargs) -> "Splitter":
          """
          Method for loading splitter from `.replay` directory.
          """
replay/utils/common.py CHANGED
@@ -1,7 +1,12 @@
+ import functools
+ import inspect
  import json
  from pathlib import Path
- from typing import Union
+ from typing import Any, Callable, Union

+ from polars import from_pandas as pl_from_pandas
+
+ from replay.data.dataset import Dataset
  from replay.splitters import (
      ColdUserRandomSplitter,
      KFolds,
@@ -12,7 +17,16 @@ from replay.splitters import (
      TimeSplitter,
      TwoStageSplitter,
  )
- from replay.utils import TORCH_AVAILABLE
+ from replay.utils import (
+     TORCH_AVAILABLE,
+     PandasDataFrame,
+     PolarsDataFrame,
+     SparkDataFrame,
+ )
+ from replay.utils.spark_utils import (
+     convert2spark as pandas_to_spark,
+     spark_to_pandas,
+ )

  SavableObject = Union[
      ColdUserRandomSplitter,
@@ -23,10 +37,11 @@ SavableObject = Union[
      RatioSplitter,
      TimeSplitter,
      TwoStageSplitter,
+     Dataset,
  ]

  if TORCH_AVAILABLE:
-     from replay.data.nn import SequenceTokenizer
+     from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer

      SavableObject = Union[
          ColdUserRandomSplitter,
@@ -38,6 +53,8 @@ if TORCH_AVAILABLE:
          TimeSplitter,
          TwoStageSplitter,
          SequenceTokenizer,
+         PandasSequentialDataset,
+         PolarsSequentialDataset,
      ]


@@ -50,7 +67,7 @@ def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
      obj.save(path)


- def load_from_replay(path: Union[str, Path]) -> SavableObject:
+ def load_from_replay(path: Union[str, Path], **kwargs) -> SavableObject:
      """
      General function to load RePlay models, splitters and tokenizer.

@@ -60,6 +77,91 @@ def load_from_replay(path: Union[str, Path]) -> SavableObject:
      with open(path / "init_args.json", "r") as file:
          class_name = json.loads(file.read())["_class_name"]
      obj_type = globals()[class_name]
-     obj = obj_type.load(path)
+     obj = obj_type.load(path, **kwargs)

      return obj
+
+
+ def _check_if_dataframe(var: Any):
+     if not isinstance(var, (SparkDataFrame, PolarsDataFrame, PandasDataFrame)):
+         msg = f"Object of type {type(var)} is not a dataframe of known type (can be pandas|spark|polars)"
+         raise ValueError(msg)
+
+
+ def check_if_dataframe(*args_to_check: str) -> Callable[..., Any]:
+     def decorator_func(func: Callable[..., Any]) -> Callable[..., Any]:
+         @functools.wraps(func)
+         def wrap_func(*args: Any, **kwargs: Any) -> Any:
+             extended_kwargs = {}
+             extended_kwargs.update(kwargs)
+             extended_kwargs.update(dict(zip(inspect.signature(func).parameters.keys(), args)))
+             # add default param values to dict with arguments
+             extended_kwargs.update(
+                 {
+                     x.name: x.default
+                     for x in inspect.signature(func).parameters.values()
+                     if x.name not in extended_kwargs and x.default is not x.empty
+                 }
+             )
+             vals_to_check = [extended_kwargs[_arg] for _arg in args_to_check]
+             for val in vals_to_check:
+                 _check_if_dataframe(val)
+             return func(*args, **kwargs)
+
+         return wrap_func
+
+     return decorator_func
+
+
+ @check_if_dataframe("data")
+ def convert2pandas(
+     data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
+ ) -> PandasDataFrame:
+     """
+     Convert the spark|polars DataFrame to a pandas.DataFrame.
+     Returns unchanged dataframe if the input is already of type pandas.DataFrame.
+
+     :param data: The dataframe to convert. Can be polars|spark|pandas DataFrame.
+     :param allow_collect_to_master: If set to False (default) raises a warning
+         about collecting parallelized data to the master node.
+     """
+     if isinstance(data, PandasDataFrame):
+         return data
+     if isinstance(data, PolarsDataFrame):
+         return data.to_pandas()
+     if isinstance(data, SparkDataFrame):
+         return spark_to_pandas(data, allow_collect_to_master, from_constructor=False)
+
+
+ @check_if_dataframe("data")
+ def convert2polars(
+     data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame], allow_collect_to_master: bool = False
+ ) -> PolarsDataFrame:
+     """
+     Convert the spark|pandas DataFrame to a polars.DataFrame.
+     Returns unchanged dataframe if the input is already of type polars.DataFrame.
+
+     :param data: The dataframe to convert. Can be spark|pandas|polars DataFrame.
+     :param allow_collect_to_master: If set to False (default) raises a warning
+         about collecting parallelized data to the master node.
+     """
+     if isinstance(data, PandasDataFrame):
+         return pl_from_pandas(data)
+     if isinstance(data, PolarsDataFrame):
+         return data
+     if isinstance(data, SparkDataFrame):
+         return pl_from_pandas(spark_to_pandas(data, allow_collect_to_master, from_constructor=False))
+
+
+ @check_if_dataframe("data")
+ def convert2spark(data: Union[SparkDataFrame, PolarsDataFrame, PandasDataFrame]) -> SparkDataFrame:
+     """
+     Convert the pandas|polars DataFrame to a pysaprk.sql.DataFrame.
+     Returns unchanged dataframe if the input is already of type pysaprk.sql.DataFrame.
+
+     :param data: The dataframe to convert. Can be pandas|polars|spark Datarame.
+     """
+     if isinstance(data, (PandasDataFrame, SparkDataFrame)):
+         return pandas_to_spark(data)
+     if isinstance(data, PolarsDataFrame):
+         return pandas_to_spark(data.to_pandas())
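The new conversion helpers give one entry point for moving an interaction log between backends, and check_if_dataframe guards them (or any user function) against non-dataframe arguments by inspecting the decorated function's signature. A short usage sketch, assuming pandas, polars and pyspark are installed and a local Spark session can be started:

    import pandas as pd

    from replay.utils.common import check_if_dataframe, convert2pandas, convert2polars, convert2spark

    log = pd.DataFrame({"user_id": [0, 1], "item_id": [1, 2]})

    log_pl = convert2polars(log)       # pandas -> polars via polars.from_pandas
    log_sp = convert2spark(log_pl)     # polars -> spark (through pandas)
    log_pd = convert2pandas(log_sp, allow_collect_to_master=True)  # spark -> pandas, warning suppressed

    # The decorator also works on user code: it raises ValueError if the named
    # argument is not a pandas, polars or Spark dataframe.
    @check_if_dataframe("data")
    def n_rows(data):
        return len(convert2pandas(data, allow_collect_to_master=True))

    print(n_rows(log_sp))  # 2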
replay/utils/spark_utils.py CHANGED
@@ -33,7 +33,9 @@ class SparkCollectToMasterWarning(Warning): # pragma: no cover
      """


- def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False) -> pd.DataFrame: # pragma: no cover
+ def spark_to_pandas(
+     data: SparkDataFrame, allow_collect_to_master: bool = False, from_constructor: bool = True
+ ) -> pd.DataFrame: # pragma: no cover
      """
      Convert Spark DataFrame to Pandas DataFrame.

@@ -42,10 +44,15 @@ def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False)

      :returns: Converted Pandas DataFrame.
      """
+     warn_msg = "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
+     if from_constructor:
+         _msg = "To remove this warning set allow_collect_to_master=True in the recommender constructor."
+     else:
+         _msg = "To remove this warning set allow_collect_to_master=True."
+     warn_msg += _msg
      if not allow_collect_to_master:
          warnings.warn(
-             "Spark Data Frame is collected to master node, this may lead to OOM exception for larger dataset. "
-             "To remove this warning set allow_collect_to_master=True in the recommender constructor.",
+             warn_msg,
              SparkCollectToMasterWarning,
          )
      return data.toPandas()
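The collect-to-master warning text now depends on the call site: the default from_constructor=True keeps the old hint about the recommender constructor, while the new convert2pandas/convert2polars helpers pass from_constructor=False so the hint simply says to set allow_collect_to_master=True. A quick way to see both messages, assuming a local Spark session is available:

    import warnings

    import pandas as pd

    from replay.utils.spark_utils import convert2spark, spark_to_pandas

    sdf = convert2spark(pd.DataFrame({"user_id": [0], "item_id": [1]}))
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        spark_to_pandas(sdf)                          # "... in the recommender constructor."
        spark_to_pandas(sdf, from_constructor=False)  # "... set allow_collect_to_master=True."
    print([str(w.message) for w in caught])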
@@ -169,7 +176,7 @@ if PYSPARK_AVAILABLE:
  <BLANKLINE>
  >>> output_data = input_data.select(vector_dot("one", "two").alias("dot"))
  >>> output_data.schema
- StructType(List(StructField(dot,DoubleType,true)))
+ StructType([StructField('dot', DoubleType(), True)])
  >>> output_data.show()
  +----+
  | dot|
@@ -207,7 +214,7 @@ if PYSPARK_AVAILABLE:
  <BLANKLINE>
  >>> output_data = input_data.select(vector_mult("one", "two").alias("mult"))
  >>> output_data.schema
- StructType(List(StructField(mult,VectorUDT,true)))
+ StructType([StructField('mult', VectorUDT(), True)])
  >>> output_data.show()
  +---------+
  |     mult|
@@ -244,7 +251,7 @@ if PYSPARK_AVAILABLE:
  <BLANKLINE>
  >>> output_data = input_data.select(array_mult("one", "two").alias("mult"))
  >>> output_data.schema
- StructType(List(StructField(mult,ArrayType(DoubleType,true),true)))
+ StructType([StructField('mult', ArrayType(DoubleType(), True), True)])
  >>> output_data.show()
  +----------+
  |      mult|
{replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: replay-rec
- Version: 0.17.0rc0
+ Version: 0.17.1
  Summary: RecSys Library
  Home-page: https://sb-ai-lab.github.io/RePlay/
  License: Apache-2.0
@@ -20,25 +20,17 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: all
  Provides-Extra: spark
  Provides-Extra: torch
- Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
- Requires-Dist: gym (>=0.26.0,<0.27.0)
  Requires-Dist: hnswlib (==0.7.0)
- Requires-Dist: implicit (>=0.7.0,<0.8.0)
- Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
- Requires-Dist: lightfm (==1.17)
  Requires-Dist: lightning (>=2.0.2,<3.0.0) ; extra == "torch" or extra == "all"
- Requires-Dist: llvmlite (>=0.32.1)
  Requires-Dist: nmslib (==2.1.1)
- Requires-Dist: numba (>=0.50)
  Requires-Dist: numpy (>=1.20.0)
  Requires-Dist: optuna (>=3.2.0,<3.3.0)
- Requires-Dist: pandas (>=1.3.5,<2.0.0)
+ Requires-Dist: pandas (>=1.3.5,<=2.2.2)
  Requires-Dist: polars (>=0.20.7,<0.21.0)
  Requires-Dist: psutil (>=5.9.5,<5.10.0)
  Requires-Dist: pyarrow (>=12.0.1)
- Requires-Dist: pyspark (>=3.0,<3.3) ; extra == "spark" or extra == "all"
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
  Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
- Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
  Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
  Requires-Dist: scipy (>=1.8.1,<1.9.0)
  Requires-Dist: torch (>=1.8,<2.0) ; extra == "torch" or extra == "all"
{replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/RECORD CHANGED
@@ -1,68 +1,16 @@
- replay/__init__.py,sha256=y6Ms_dBdP_0tx6CPUF9QV0jrhb-ogRReafA6edgal_E,54
+ replay/__init__.py,sha256=wUk_ODIXbOTEQKc4cIBpsptZ--yblkgTGRfXStYmQKI,46
  replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
- replay/data/dataset.py,sha256=ysMTNfx8I2hI9fSugtt3IPhenmutgzQMw-8VcM3oUJk,21299
+ replay/data/dataset.py,sha256=cSStvCqIc6WAJNtbmsxncSpcQZ1KfULMsrmf_V0UdPw,29490
  replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
  replay/data/dataset_utils/dataset_label_encoder.py,sha256=TEx2zLw5rJdIz1SRBEznyVv5x_Cs7o6QQbzMk-M1LU0,9598
  replay/data/nn/__init__.py,sha256=WxLsi4rgOuuvGYHN49xBPxP2Srhqf3NYgfBDVH-ZvBo,1122
- replay/data/nn/schema.py,sha256=BYU65vLqPDl69OE-rReh59fiQK0ERfs1xbBLWCiIJnw,14258
- replay/data/nn/sequence_tokenizer.py,sha256=dXD8l7IfK1dod8p--I6BhvE9af3iUOfpaoW2QBU9hTs,34133
- replay/data/nn/sequential_dataset.py,sha256=fqlyBAzDmpH332S-LoMP9PoRYMtgZpxG6Qdahmk5GtE,7840
+ replay/data/nn/schema.py,sha256=pO4N7RgmgrqfD1-2d95OTeihKHTZ-5y2BG7CX_wBFi4,16198
+ replay/data/nn/sequence_tokenizer.py,sha256=Ambrp3CMOp3JP68PiwmVh0m-_zNXiWzxxVreHkEwOyY,32592
+ replay/data/nn/sequential_dataset.py,sha256=jCWxC0Pm1eQ5p8Y6_Bmg4fSEvPaecLrqz1iaWzaICdI,11014
  replay/data/nn/torch_sequential_dataset.py,sha256=BqrK_PtkhpsaY1zRIWGk4EgwPL31a7IWCc0hLDuwDQc,10984
  replay/data/nn/utils.py,sha256=YKE9gkIHZDDiwv4THqOWL4PzsdOujnPuM97v79Mwq0E,2769
  replay/data/schema.py,sha256=F_cv6sYb6l23yuX5xWnbqoJ9oSeUT2NpIM19u8Lf2jA,15606
  replay/data/spark_schema.py,sha256=4o0Kn_fjwz2-9dBY3q46F9PL0F3E7jdVpIlX7SG3OZI,1111
- replay/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- replay/experimental/metrics/__init__.py,sha256=W6S9YTGCezLORyTKCqL4Y_PniC1k3Bu5XWIM3WVHg2Q,2860
- replay/experimental/metrics/base_metric.py,sha256=aYmKZ_336dRrlslBzYsgsOzmed54BNjNXsRcpzB5gyM,22648
- replay/experimental/metrics/coverage.py,sha256=3kVBAUhIEOuD8aJ6DShH2xh_1F61dcLZb001VCkmeJk,3154
- replay/experimental/metrics/experiment.py,sha256=Bd_XB9zbngcAwf5JLZKVPsFWQoz9pEGlPEUbkiR_MDc,7343
- replay/experimental/metrics/hitrate.py,sha256=TfWJrUyZXabdMr4tn8zqUPGDcYy2yphVCzXmLSHCxY0,675
- replay/experimental/metrics/map.py,sha256=S4dKiMpYR0_pu0bqioGMT0kIC1s2aojFP4rddBqMPtM,921
- replay/experimental/metrics/mrr.py,sha256=q6I1Cndlwr716mMuYtTMu0lN8Rrp9khxhb49OM2IpV8,530
- replay/experimental/metrics/ncis_precision.py,sha256=yrErOhBZvZdNpQPx_AXyktDJatqdWRIHNMyei0QDJtQ,1088
- replay/experimental/metrics/ndcg.py,sha256=q3KTsyZCrfvcpEjEnR_kWVB9ZaTFRxnoNRAr2WD0TrU,1538
- replay/experimental/metrics/precision.py,sha256=U9pD9yRGeT8uH32BTyQ-W5qsAnbFWu-pqy4XfkcXfCM,664
- replay/experimental/metrics/recall.py,sha256=5xRPGxfbVoDFEI5E6dVlZpT4RvnDlWzaktyoqh3a8mc,774
- replay/experimental/metrics/rocauc.py,sha256=yq4vW2_bXO8HCjREBZVrHMKeZ054LYvjJmLJTXWPfQA,1675
- replay/experimental/metrics/surprisal.py,sha256=CK4_zed2bSMDwC7ZBCS8d8RwGEqt8bh3w3fTpjKiK6Y,3052
- replay/experimental/metrics/unexpectedness.py,sha256=JQQXEYHtQM8nqp7X2He4E9ZYwbpdENaK8oQG7sUQT3s,2621
- replay/experimental/models/__init__.py,sha256=R284PXgSxt-JWWwlSTLggchash0hrLfy4b2w-ySaQf4,588
- replay/experimental/models/admm_slim.py,sha256=Oz-x0aQAnGFN9z7PB7MiKfduBasc4KQrBT0JwtYdwLY,6581
- replay/experimental/models/base_neighbour_rec.py,sha256=pRcffr0cdRNZRVpzWb2Qv-UIsLkhbs7K1GRAmrSqPSM,7506
- replay/experimental/models/base_rec.py,sha256=rj2r7r_mmJdzKAkg5CHG1eqJhOpUHAETPe0NwfibFjU,49606
- replay/experimental/models/base_torch_rec.py,sha256=oDkCxVFQjIHSWKlCns6mU3ECWbQW3mQZWvBHBxJQdwc,8111
- replay/experimental/models/cql.py,sha256=9ONDMblfxUgol5Pb2UInfSHVRbB2Ma15zAZC6valhtk,19628
- replay/experimental/models/ddpg.py,sha256=sZrGgwj_kKeUnwwT9qooc4Cxz2oVGkNfUwUe1N7mreI,31982
- replay/experimental/models/dt4rec/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- replay/experimental/models/dt4rec/dt4rec.py,sha256=ZIHYonDubStN7Gb703csy86R7Q3_1fZc4zJf98HYFe4,5895
- replay/experimental/models/dt4rec/gpt1.py,sha256=T3buFtYyF6Fh6sW6f9dUZFcFEnQdljItbRa22CiKb0w,14044
- replay/experimental/models/dt4rec/trainer.py,sha256=YeaJ8mnoYZqnPwm1P9qOYb8GzgFC5At-JeSDcvG2V2o,3859
- replay/experimental/models/dt4rec/utils.py,sha256=jbCx2Xc85VtjQx-caYhJFfVuj1Wf866OAiSoZlR4q48,8201
- replay/experimental/models/extensions/spark_custom_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py,sha256=dKSVCMXWRB7IUnpEK_QNhSEuUSVcG793E8MT_AGXneY,25890
- replay/experimental/models/implicit_wrap.py,sha256=8F-f-CaStmlNHwphu-yu8o4Aft08NKDD_SqqH0zp1Uo,4655
- replay/experimental/models/lightfm_wrap.py,sha256=a2ctIEoZf7I0C_awiQI1lE4RGJ7ISs60znysgHRXZCw,11337
- replay/experimental/models/mult_vae.py,sha256=FdJ-GL6Jj2l5-38edKp_jsNfwFNGPxMHXKn8cG2tGJs,11607
- replay/experimental/models/neuromf.py,sha256=QRu--zIyOSQIp8R5Ksgiw7o0s5yOhQpuAX9YshKJs4w,14391
- replay/experimental/models/scala_als.py,sha256=PVf0YA3ii4iRwGqpYg6nStgaauyrm9QTzLtK_4f1En0,10985
- replay/experimental/nn/data/__init__.py,sha256=5EAF-FNd7xhkUpTq_5MyVcPXBD81mJCwYrcbhdGOWjE,48
- replay/experimental/nn/data/schema_builder.py,sha256=5PphL9kK-tVm30nWdTjHUzqVOnTwKiU_MlxGdL5HJ8Y,1736
- replay/experimental/preprocessing/__init__.py,sha256=uMyeyQ_GKqjLhVGwhrEk3NLhhzS0DKi5xGo3VF4WkiA,130
- replay/experimental/preprocessing/data_preparator.py,sha256=fQ8Blo_uzA-2eC-_ViVeU26Tqj5lxLTCBoDJfEmiqUo,35968
- replay/experimental/preprocessing/padder.py,sha256=o7S_Zk-ne_jria3QhWCKkYa6bEqhCdtvCA-R0MjOvU4,9569
- replay/experimental/preprocessing/sequence_generator.py,sha256=E1_0uZJLv8V_n6YzRlgUWtcrHIdjNwPeBN-BMbz0e-A,9053
- replay/experimental/scenarios/__init__.py,sha256=gWFLCkLyOmOppvbRMK7C3UMlMpcbIgiGVolSH6LPgWA,91
- replay/experimental/scenarios/obp_wrapper/__init__.py,sha256=rsRyfsTnVNp20LkTEugwoBrV9XWbIhR8tOqec_Au6dY,450
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py,sha256=vmLANYB5i1UR3uY7e-T0IBEYwPxOYHtKqhkmUvMUYhU,2548
- replay/experimental/scenarios/obp_wrapper/replay_offline.py,sha256=A6TPBFHj_UUL0N6DHSF0-hACsH5cw2o1GMYvpPS6964,8756
- replay/experimental/scenarios/obp_wrapper/utils.py,sha256=-ioWTb73NmHWxVxw4BdSolctqeeGIyjKtydwc45nrrk,3271
- replay/experimental/scenarios/two_stages/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- replay/experimental/scenarios/two_stages/reranker.py,sha256=tJtWhbHRNV4sJZ9RZzqIfylTplKh9QVwTIBhEGGnXq8,4244
- replay/experimental/scenarios/two_stages/two_stages_scenario.py,sha256=ZgflnQ6xuxDFphdKX6Q0jtXidHS7c2YvDaccoaL78Qo,29846
- replay/experimental/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- replay/experimental/utils/logger.py,sha256=UwLowaeOG17sDEe32LiZel8MnjSTzeW7J3uLG1iwLuA,639
- replay/experimental/utils/model_handler.py,sha256=0ksSm5bJ1bL32VV5HI-KPe0a1EAzzOhMtmSYaM_zRrE,6271
- replay/experimental/utils/session_handler.py,sha256=076TLpTOcnh13BznNTtJW6Zhrqvm9Ee1mlpP5YMD4No,1313
  replay/metrics/__init__.py,sha256=KDkxVnKa4ks9K9GmlrdTx1pkIl-MAmm78ZASsp2ZndE,2812
  replay/metrics/base_metric.py,sha256=uleW5vLrdA3iRx72tFyW0cxe6ne_ugQ1XaY_ZTcnAOo,15960
  replay/metrics/categorical_diversity.py,sha256=OYsF-Ng-WrF9CC-sKgQKngrA779NO8MtgRvvAyC8MXM,10781
@@ -148,14 +96,14 @@ replay/optimization/__init__.py,sha256=az6U10rF7X6rPRUUPwLyiM1WFNJ_6kl0imA5xLVWF
  replay/optimization/optuna_objective.py,sha256=Z-8X0_FT3BicVWj0UhxoLrvZAck3Dhn7jHDGo0i0hxA,7653
  replay/preprocessing/__init__.py,sha256=TtBysFqYeDy4kZAEnWEaNSwPvbffYdfMkEs71YG51fM,411
  replay/preprocessing/converter.py,sha256=DczqsVLrwFi6EFhK2HR8rGiIxGCwXeY7QNgWorjA41g,4390
- replay/preprocessing/filters.py,sha256=6MaO4IIyKNFP2AR94YA5iQUhQvuCRhAFfj0opI6o4-Q,33744
+ replay/preprocessing/filters.py,sha256=wsXWQoZ-2aAecunLkaTxeLWi5ow4e3FAGcElx0iNx0w,41669
  replay/preprocessing/history_based_fp.py,sha256=tfgKJPKm53LSNqM6VmMXYsVrRDc-rP1Tbzn8s3mbziQ,18751
  replay/preprocessing/label_encoder.py,sha256=MLBavPD-dB644as0E9ZJSE9-8QxGCB_IHek1w3xtqDI,27040
  replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
  replay/scenarios/__init__.py,sha256=kw2wRkPPinw0IBA20D83XQ3xeSudk3KuYAAA1Wdr8xY,93
  replay/scenarios/fallback.py,sha256=EeBmIR-5igzKR2m55bQRFyhxTkpJez6ZkCW449n8hWs,7130
  replay/splitters/__init__.py,sha256=DnqVMelrzLwR8fGQgcWN_8FipGs8T4XGSPOMW-L_x2g,454
- replay/splitters/base_splitter.py,sha256=qWW8Sueu0BrYt0WIxMbzooAC4-jhEmyd6pMND_H_qB0,7751
+ replay/splitters/base_splitter.py,sha256=hj9_GYDWllzv3XnxN6WHu1JKRRVjXo77vZEOLbF9v-s,7761
  replay/splitters/cold_user_random_splitter.py,sha256=gVwBVdn_0IOaLGT_UzJoS9AMaPhelZy-FpC5JQS1PhA,4136
  replay/splitters/k_folds.py,sha256=WH02_DP18A2ae893ysonmfLPB56_i1ETllTAwaCYekg,6218
  replay/splitters/last_n_splitter.py,sha256=r9kdq2JPi508C9ywjwc68an-iq27KsigMfHWLz0YohE,15346
@@ -165,16 +113,15 @@ replay/splitters/ratio_splitter.py,sha256=8zvuCn16Icc4ntQPKXJ5ArAWuJzCZ9NHZtgWct
  replay/splitters/time_splitter.py,sha256=iXhuafjBx7dWyJSy-TEVy1IUQBwMpA1gAiF4-GtRe2g,9031
  replay/splitters/two_stage_splitter.py,sha256=PWozxjjgjrVzdz6Sm9dcDTeH0bOA24reFzkk_N_TgbQ,17734
  replay/utils/__init__.py,sha256=vDJgOWq81fbBs-QO4ZDpdqR4KDyO1kMOOxBRi-5Gp7E,253
- replay/utils/common.py,sha256=6JxR5bFuTFTFWad36J5Zu8dFgpFXoof6VsVpF2sD7h8,1471
+ replay/utils/common.py,sha256=s4Pro3QCkPeVBsj-s0vrbhd_pkJD-_-2M_sIguxGzQQ,5411
  replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
  replay/utils/distributions.py,sha256=kGGq2KzQZ-yhTuw_vtOsKFXVpXUOQ2l4aIFBcaDufZ8,1202
  replay/utils/model_handler.py,sha256=V-mHDh8_UexjVSsMBBRA9yrjS_5MPHwYOwv_UrI-Zfs,6466
  replay/utils/session_handler.py,sha256=ijTvDSNAe1D9R1e-dhtd-r80tFNiIBsFdWZLgw-gLEo,5153
- replay/utils/spark_utils.py,sha256=PhNi9fW28ek0ZB90AUg3tsT5BULbQjDhLalxxww9eLE,26700
+ replay/utils/spark_utils.py,sha256=k5lUFM2C9QZKQON3dqhgfswyUF4tsgJOn0U2wCKimqM,26901
  replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
  replay/utils/types.py,sha256=5sw0A7NG4ZgQKdWORnBy0wBZ5F98sP_Ju8SKQ6zbDS4,651
- replay_rec-0.17.0rc0.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
- replay_rec-0.17.0rc0.dist-info/METADATA,sha256=8Ki81O8-t1bWieQu4WJFFNWMu4CrvhwBSaU0mcfhh4o,10889
- replay_rec-0.17.0rc0.dist-info/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
- replay_rec-0.17.0rc0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- replay_rec-0.17.0rc0.dist-info/RECORD,,
+ replay_rec-0.17.1.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
+ replay_rec-0.17.1.dist-info/METADATA,sha256=IDkSzO_PcQgyU4Xqnpi0WTHkqyVS0t3vNvisONZaBLg,10589
+ replay_rec-0.17.1.dist-info/WHEEL,sha256=Zb28QaM1gQi8f4VCBhsUklF61CTlNYfs9YAZn-TOGFk,88
+ replay_rec-0.17.1.dist-info/RECORD,,
{replay_rec-0.17.0rc0.dist-info → replay_rec-0.17.1.dist-info}/WHEEL CHANGED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 1.9.0
+ Generator: poetry-core 1.6.1
  Root-Is-Purelib: true
  Tag: py3-none-any
replay/experimental/__init__.py
File without changes
replay/experimental/metrics/__init__.py DELETED
@@ -1,61 +0,0 @@
- """
- Most metrics require dataframe with recommendations
- and dataframe with ground truth values —
- which objects each user interacted with.
-
- - recommendations (Union[pandas.DataFrame, spark.DataFrame]):
-     predictions of a recommender system,
-     DataFrame with columns ``[user_id, item_id, relevance]``
- - ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
-     test data, DataFrame with columns
-     ``[user_id, item_id, timestamp, relevance]``
-
- Metric is calculated for all users, presented in ``ground_truth``
- for accurate metric calculation in case when the recommender system generated
- recommendation not for all users. It is assumed, that all users,
- we want to calculate metric for, have positive interactions.
-
- But if we have users, who observed the recommendations, but have not responded,
- those users will be ignored and metric will be overestimated.
- For such case we propose additional optional parameter ``ground_truth_users``,
- the dataframe with all users, which should be considered during the metric calculation.
-
- - ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
-     full list of users to calculate metric for, DataFrame with ``user_id`` column
-
- Every metric is calculated using top ``K`` items for each user.
- It is also possible to calculate metrics
- using multiple values for ``K`` simultaneously.
- In this case the result will be a dictionary and not a number.
-
- Make sure your recommendations do not contain user-item duplicates
- as duplicates could lead to the wrong calculation results.
-
- - k (Union[Iterable[int], int]):
-     a single number or a list, specifying the
-     truncation length for recommendation list for each user
-
- By default, metrics are averaged by users,
- but you can alternatively use method ``metric.median``.
- Also, you can get the lower bound
- of ``conf_interval`` for a given ``alpha``.
-
- Diversity metrics require extra parameters on initialization stage,
- but do not use ``ground_truth`` parameter.
-
- For each metric, a formula for its calculation is given, because this is
- important for the correct comparison of algorithms, as mentioned in our
- `article <https://arxiv.org/abs/2206.12858>`_.
- """
- from replay.experimental.metrics.base_metric import Metric, NCISMetric
- from replay.experimental.metrics.coverage import Coverage
- from replay.experimental.metrics.hitrate import HitRate
- from replay.experimental.metrics.map import MAP
- from replay.experimental.metrics.mrr import MRR
- from replay.experimental.metrics.ncis_precision import NCISPrecision
- from replay.experimental.metrics.ndcg import NDCG
- from replay.experimental.metrics.precision import Precision
- from replay.experimental.metrics.recall import Recall
- from replay.experimental.metrics.rocauc import RocAuc
- from replay.experimental.metrics.surprisal import Surprisal
- from replay.experimental.metrics.unexpectedness import Unexpectedness