replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +11 -0
  3. replay/data/nn/__init__.py +3 -0
  4. replay/data/nn/parquet/__init__.py +22 -0
  5. replay/data/nn/parquet/collate.py +29 -0
  6. replay/data/nn/parquet/constants/batches.py +8 -0
  7. replay/data/nn/parquet/constants/device.py +3 -0
  8. replay/data/nn/parquet/constants/filesystem.py +3 -0
  9. replay/data/nn/parquet/constants/metadata.py +5 -0
  10. replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
  11. replay/data/nn/parquet/impl/array_1d_column.py +140 -0
  12. replay/data/nn/parquet/impl/array_2d_column.py +160 -0
  13. replay/data/nn/parquet/impl/column_protocol.py +17 -0
  14. replay/data/nn/parquet/impl/indexing.py +123 -0
  15. replay/data/nn/parquet/impl/masking.py +20 -0
  16. replay/data/nn/parquet/impl/named_columns.py +100 -0
  17. replay/data/nn/parquet/impl/numeric_column.py +110 -0
  18. replay/data/nn/parquet/impl/utils.py +17 -0
  19. replay/data/nn/parquet/info/distributed_info.py +40 -0
  20. replay/data/nn/parquet/info/partitioning.py +132 -0
  21. replay/data/nn/parquet/info/replicas.py +67 -0
  22. replay/data/nn/parquet/info/worker_info.py +43 -0
  23. replay/data/nn/parquet/iterable_dataset.py +119 -0
  24. replay/data/nn/parquet/iterator.py +61 -0
  25. replay/data/nn/parquet/metadata/__init__.py +19 -0
  26. replay/data/nn/parquet/metadata/metadata.py +116 -0
  27. replay/data/nn/parquet/parquet_dataset.py +176 -0
  28. replay/data/nn/parquet/parquet_module.py +178 -0
  29. replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
  30. replay/data/nn/parquet/utils/compute_length.py +66 -0
  31. replay/data/nn/schema.py +12 -14
  32. replay/data/nn/sequence_tokenizer.py +5 -0
  33. replay/data/nn/sequential_dataset.py +4 -0
  34. replay/data/nn/torch_sequential_dataset.py +5 -0
  35. replay/data/utils/batching.py +69 -0
  36. replay/data/utils/typing/__init__.py +0 -0
  37. replay/data/utils/typing/dtype.py +65 -0
  38. replay/metrics/torch_metrics_builder.py +20 -14
  39. replay/models/nn/loss/sce.py +2 -7
  40. replay/models/nn/optimizer_utils/__init__.py +6 -1
  41. replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
  42. replay/models/nn/sequential/bert4rec/dataset.py +70 -29
  43. replay/models/nn/sequential/bert4rec/lightning.py +97 -36
  44. replay/models/nn/sequential/bert4rec/model.py +11 -11
  45. replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
  46. replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
  47. replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
  48. replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
  49. replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
  50. replay/models/nn/sequential/postprocessors/_base.py +5 -0
  51. replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
  52. replay/models/nn/sequential/sasrec/dataset.py +81 -26
  53. replay/models/nn/sequential/sasrec/lightning.py +86 -24
  54. replay/models/nn/sequential/sasrec/model.py +14 -9
  55. replay/nn/__init__.py +8 -0
  56. replay/nn/agg.py +109 -0
  57. replay/nn/attention.py +158 -0
  58. replay/nn/embedding.py +283 -0
  59. replay/nn/ffn.py +135 -0
  60. replay/nn/head.py +49 -0
  61. replay/nn/lightning/__init__.py +1 -0
  62. replay/nn/lightning/callback/__init__.py +9 -0
  63. replay/nn/lightning/callback/metrics_callback.py +183 -0
  64. replay/nn/lightning/callback/predictions_callback.py +314 -0
  65. replay/nn/lightning/module.py +123 -0
  66. replay/nn/lightning/optimizer.py +60 -0
  67. replay/nn/lightning/postprocessor/__init__.py +2 -0
  68. replay/nn/lightning/postprocessor/_base.py +51 -0
  69. replay/nn/lightning/postprocessor/seen_items.py +83 -0
  70. replay/nn/lightning/scheduler.py +91 -0
  71. replay/nn/loss/__init__.py +22 -0
  72. replay/nn/loss/base.py +197 -0
  73. replay/nn/loss/bce.py +216 -0
  74. replay/nn/loss/ce.py +317 -0
  75. replay/nn/loss/login_ce.py +373 -0
  76. replay/nn/loss/logout_ce.py +230 -0
  77. replay/nn/mask.py +87 -0
  78. replay/nn/normalization.py +9 -0
  79. replay/nn/output.py +37 -0
  80. replay/nn/sequential/__init__.py +9 -0
  81. replay/nn/sequential/sasrec/__init__.py +7 -0
  82. replay/nn/sequential/sasrec/agg.py +53 -0
  83. replay/nn/sequential/sasrec/diff_transformer.py +125 -0
  84. replay/nn/sequential/sasrec/model.py +377 -0
  85. replay/nn/sequential/sasrec/transformer.py +107 -0
  86. replay/nn/sequential/twotower/__init__.py +2 -0
  87. replay/nn/sequential/twotower/model.py +674 -0
  88. replay/nn/sequential/twotower/reader.py +89 -0
  89. replay/nn/transform/__init__.py +22 -0
  90. replay/nn/transform/copy.py +38 -0
  91. replay/nn/transform/grouping.py +39 -0
  92. replay/nn/transform/negative_sampling.py +182 -0
  93. replay/nn/transform/next_token.py +100 -0
  94. replay/nn/transform/rename.py +33 -0
  95. replay/nn/transform/reshape.py +41 -0
  96. replay/nn/transform/sequence_roll.py +48 -0
  97. replay/nn/transform/template/__init__.py +2 -0
  98. replay/nn/transform/template/sasrec.py +53 -0
  99. replay/nn/transform/template/twotower.py +22 -0
  100. replay/nn/transform/token_mask.py +69 -0
  101. replay/nn/transform/trim.py +51 -0
  102. replay/nn/utils.py +28 -0
  103. replay/preprocessing/filters.py +128 -0
  104. replay/preprocessing/label_encoder.py +36 -33
  105. replay/preprocessing/utils.py +209 -0
  106. replay/splitters/__init__.py +1 -0
  107. replay/splitters/random_next_n_splitter.py +224 -0
  108. replay/utils/common.py +10 -4
  109. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
  110. replay_rec-0.21.0.dist-info/RECORD +223 -0
  111. replay/experimental/metrics/__init__.py +0 -62
  112. replay/experimental/metrics/base_metric.py +0 -603
  113. replay/experimental/metrics/coverage.py +0 -97
  114. replay/experimental/metrics/experiment.py +0 -175
  115. replay/experimental/metrics/hitrate.py +0 -26
  116. replay/experimental/metrics/map.py +0 -30
  117. replay/experimental/metrics/mrr.py +0 -18
  118. replay/experimental/metrics/ncis_precision.py +0 -31
  119. replay/experimental/metrics/ndcg.py +0 -49
  120. replay/experimental/metrics/precision.py +0 -22
  121. replay/experimental/metrics/recall.py +0 -25
  122. replay/experimental/metrics/rocauc.py +0 -49
  123. replay/experimental/metrics/surprisal.py +0 -90
  124. replay/experimental/metrics/unexpectedness.py +0 -76
  125. replay/experimental/models/__init__.py +0 -50
  126. replay/experimental/models/admm_slim.py +0 -257
  127. replay/experimental/models/base_neighbour_rec.py +0 -200
  128. replay/experimental/models/base_rec.py +0 -1386
  129. replay/experimental/models/base_torch_rec.py +0 -234
  130. replay/experimental/models/cql.py +0 -454
  131. replay/experimental/models/ddpg.py +0 -932
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  133. replay/experimental/models/dt4rec/gpt1.py +0 -401
  134. replay/experimental/models/dt4rec/trainer.py +0 -127
  135. replay/experimental/models/dt4rec/utils.py +0 -264
  136. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  137. replay/experimental/models/hierarchical_recommender.py +0 -331
  138. replay/experimental/models/implicit_wrap.py +0 -131
  139. replay/experimental/models/lightfm_wrap.py +0 -303
  140. replay/experimental/models/mult_vae.py +0 -332
  141. replay/experimental/models/neural_ts.py +0 -986
  142. replay/experimental/models/neuromf.py +0 -406
  143. replay/experimental/models/scala_als.py +0 -293
  144. replay/experimental/models/u_lin_ucb.py +0 -115
  145. replay/experimental/nn/data/__init__.py +0 -1
  146. replay/experimental/nn/data/schema_builder.py +0 -102
  147. replay/experimental/preprocessing/__init__.py +0 -3
  148. replay/experimental/preprocessing/data_preparator.py +0 -839
  149. replay/experimental/preprocessing/padder.py +0 -229
  150. replay/experimental/preprocessing/sequence_generator.py +0 -208
  151. replay/experimental/scenarios/__init__.py +0 -1
  152. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  153. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  154. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  155. replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  156. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  157. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  158. replay/experimental/utils/logger.py +0 -24
  159. replay/experimental/utils/model_handler.py +0 -186
  160. replay/experimental/utils/session_handler.py +0 -44
  161. replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
  162. /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
  163. /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
  164. /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
  165. /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
  166. /replay/{experimental → data}/utils/__init__.py +0 -0
  167. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
  168. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
  169. {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,224 @@
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import polars as pl
6
+
7
+ from replay.utils import (
8
+ PYSPARK_AVAILABLE,
9
+ DataFrameLike,
10
+ PandasDataFrame,
11
+ PolarsDataFrame,
12
+ SparkDataFrame,
13
+ )
14
+
15
+ from .base_splitter import Splitter
16
+
17
+ if PYSPARK_AVAILABLE:
18
+ import pyspark.sql.functions as sf
19
+ from pyspark.sql import Window
20
+
21
+
22
class RandomNextNSplitter(Splitter):
    """
    Split interactions by a random position in the user sequence.
    For each user, a random cut index is sampled and the target part consists of
    the next ``N`` interactions starting from this cut; the train part contains
    all interactions before the cut. Interactions after the target window are
    discarded.

    Note: by changing the ``seed`` attribute on an existing splitter instance,
    you can obtain different splits without recreating the object. This is useful
    when you need to generate multiple randomized splits of the same dataset.

    Note: the pandas and polars backends sample cut indices with a NumPy
    ``RandomState`` seeded by ``seed``. The Spark backend instead derives each cut
    from a hash of the ``divide_column`` value and the seed, and uses ``0`` as the
    seed when ``seed`` is ``None``.

    >>> from datetime import datetime
    >>> import pandas as pd
    >>> columns = ["query_id", "item_id", "timestamp"]
    >>> data = [
    ...     (1, 1, "01-01-2020"),
    ...     (1, 2, "02-01-2020"),
    ...     (1, 3, "03-01-2020"),
    ...     (2, 1, "06-01-2020"),
    ...     (2, 2, "07-01-2020"),
    ...     (2, 3, "08-01-2020"),
    ... ]
    >>> dataset = pd.DataFrame(data, columns=columns)
    >>> dataset["timestamp"] = pd.to_datetime(dataset["timestamp"], format="%d-%m-%Y")
    >>> splitter = RandomNextNSplitter(
    ...     N=2,
    ...     divide_column="query_id",
    ...     seed=42,
    ...     query_column="query_id",
    ...     item_column="item_id",
    ... )
    >>> train, test = splitter.split(dataset)
    """

    _init_arg_names = [
        "N",
        "divide_column",
        "seed",
        "drop_cold_users",
        "drop_cold_items",
        "query_column",
        "item_column",
        "timestamp_column",
        "session_id_column",
        "session_id_processing_strategy",
    ]

    def __init__(
        self,
        N: Optional[int] = 1,  # noqa: N803
        divide_column: str = "query_id",
        seed: Optional[int] = None,
        query_column: str = "query_id",
        drop_cold_users: bool = False,
        drop_cold_items: bool = False,
        item_column: str = "item_id",
        timestamp_column: str = "timestamp",
        session_id_column: Optional[str] = None,
        session_id_processing_strategy: str = "test",
    ):
        """
        :param N: Optional window size. If None, the test set contains all interactions
            from the cut to the end; otherwise the next ``N`` interactions. Must be >= 1.
            Default: 1.
        :param divide_column: Name of the column used to group interactions
            for random cut sampling, default: ``query_id``.
        :param seed: Random seed used to sample cut indices, default: ``None``.
            The Spark backend uses ``0`` when the seed is ``None``.
        :param query_column: Name of query interaction column.
        :param drop_cold_users: Drop users from test DataFrame which are not in
            the train DataFrame, default: ``False``.
        :param drop_cold_items: Drop items from test DataFrame which are not in
            the train DataFrame, default: ``False``.
        :param item_column: Name of item interaction column.
            If ``drop_cold_items`` is ``False``, then you can omit this
            parameter. Default: ``item_id``.
        :param timestamp_column: Name of time column, default: ``timestamp``.
        :param session_id_column: Name of session id column whose values cannot
            be split between train/test, default: ``None``.
        :param session_id_processing_strategy: Strategy to process a session if
            it crosses the boundary: ``train`` or ``test``. ``train`` means the
            whole session goes to train, ``test`` — the whole session goes to
            test. Default: ``test``.
        :raises ValueError: if ``N`` is not ``None`` and is less than 1.
        """

        super().__init__(
            drop_cold_users=drop_cold_users,
            drop_cold_items=drop_cold_items,
            query_column=query_column,
            item_column=item_column,
            timestamp_column=timestamp_column,
            session_id_column=session_id_column,
            session_id_processing_strategy=session_id_processing_strategy,
        )
        self.N = N
        if self.N is not None and self.N < 1:
            msg = "N must be >= 1"
            raise ValueError(msg)
        self.divide_column = divide_column
        self.seed = seed

    def _sample_cuts(self, counts: np.ndarray) -> np.ndarray:
        """
        Sample one cut index per group.

        :param counts: Number of interactions in each group.
        :return: Array of the same length as ``counts`` where element ``i`` is an
            integer drawn from ``[0, counts[i])``.
        """
        # A fresh generator is built on every call so that the current value of
        # ``self.seed`` is always the one that is used.
        rng = np.random.RandomState(self.seed)
        return rng.randint(0, counts)

    def _partial_split_pandas(
        self,
        interactions: PandasDataFrame,
    ) -> tuple[PandasDataFrame, PandasDataFrame]:
        """
        Split a pandas DataFrame into train and test parts.

        :param interactions: Interactions to split.
        :return: ``(train, test)`` restricted to the columns of ``interactions``.
        """
        df = interactions.sort_values([self.divide_column, self.timestamp_column])
        # Zero-based position of every interaction inside its group.
        df["_event_rank"] = df.groupby(self.divide_column, sort=False).cumcount()

        # ``sort=False`` keeps groups in order of first appearance, so the sampled
        # cuts are assigned to groups in a reproducible order.
        counts = df.groupby(self.divide_column, sort=False).size()
        cuts = pd.Series(self._sample_cuts(counts.values), index=counts.index)
        df["_cut_index"] = df[self.divide_column].map(cuts)

        if self.N is not None:
            # Drop everything past the target window. The explicit copy makes the
            # column assignment below operate on an independent frame instead of
            # on the result of boolean indexing (avoids SettingWithCopyWarning).
            df = df[df["_event_rank"] < df["_cut_index"] + self.N].copy()

        df["is_test"] = df["_event_rank"] >= df["_cut_index"]
        if self.session_id_column:
            df = self._recalculate_with_session_id_column(df)

        train = df[~df["is_test"]][interactions.columns]
        test = df[df["is_test"]][interactions.columns]

        return train, test

    def _partial_split_polars(
        self,
        interactions: PolarsDataFrame,
    ) -> tuple[PolarsDataFrame, PolarsDataFrame]:
        """
        Split a polars DataFrame into train and test parts.

        :param interactions: Interactions to split.
        :return: ``(train, test)`` restricted to the columns of ``interactions``.
        """
        df = interactions.sort([self.divide_column, self.timestamp_column]).with_columns(
            (pl.col(self.divide_column).cum_count().over(self.divide_column) - 1).alias("_event_rank")
        )

        # ``maintain_order=True`` is required for reproducibility: without it the
        # order of groups is unspecified, so the same seed could assign the sampled
        # cuts to different groups on different runs.
        counts = df.group_by(self.divide_column, maintain_order=True).len()
        r_values = self._sample_cuts(counts["len"].to_numpy())
        cuts_df = pl.DataFrame(
            {
                self.divide_column: counts[self.divide_column],
                "_cut_index": r_values,
            }
        )
        df = df.join(cuts_df, on=self.divide_column, how="left")

        if self.N is not None:
            # Drop everything past the target window.
            df = df.filter(pl.col("_event_rank") < pl.col("_cut_index") + self.N)

        df = df.with_columns((pl.col("_event_rank") >= pl.col("_cut_index")).alias("is_test"))
        if self.session_id_column:
            df = self._recalculate_with_session_id_column(df)

        train = df.filter(~pl.col("is_test")).select(interactions.columns)
        test = df.filter(pl.col("is_test")).select(interactions.columns)

        return train, test

    def _partial_split_spark(
        self,
        interactions: SparkDataFrame,
    ) -> tuple[SparkDataFrame, SparkDataFrame]:
        """
        Split a Spark DataFrame into train and test parts.

        The cut index of a group is ``pmod(xxhash64(group value, seed), group size)``,
        with the seed replaced by ``0`` when ``self.seed`` is ``None``.

        :param interactions: Interactions to split.
        :return: ``(train, test)`` restricted to the columns of ``interactions``.
        """
        w = Window.partitionBy(self.divide_column).orderBy(self.timestamp_column)
        # ``row_number`` starts at 1; shift it to a zero-based rank.
        df = interactions.withColumn("_event_rank", sf.row_number().over(w) - sf.lit(1))

        counts = df.groupBy(self.divide_column).agg(sf.count(sf.lit(1)).alias("_count"))
        seed_lit = sf.lit(self.seed) if self.seed is not None else sf.lit(0)
        cuts = counts.select(
            self.divide_column,
            # ``pmod`` keeps the result non-negative even for a negative hash.
            sf.pmod(
                sf.xxhash64(sf.col(self.divide_column), seed_lit).cast("long"),
                sf.col("_count").cast("long"),
            )
            .cast("long")
            .alias("_cut_index"),
        )

        df = df.join(cuts, on=self.divide_column, how="left")

        if self.N is not None:
            # Drop everything past the target window.
            df = df.filter(sf.col("_event_rank") < sf.col("_cut_index") + sf.lit(self.N))

        df = df.withColumn("is_test", sf.col("_event_rank") >= sf.col("_cut_index"))
        if self.session_id_column:
            df = self._recalculate_with_session_id_column(df)

        train = df.filter(~sf.col("is_test")).select(interactions.columns)
        test = df.filter(sf.col("is_test")).select(interactions.columns)

        return train, test

    def _partial_split(self, interactions: DataFrameLike) -> tuple[DataFrameLike, DataFrameLike]:
        """
        Dispatch the split to the implementation matching the DataFrame type.

        :param interactions: pandas, polars or Spark DataFrame with interactions.
        :return: ``(train, test)`` of the same DataFrame type as the input.
        :raises NotImplementedError: if the DataFrame type is not supported.
        """
        if isinstance(interactions, PandasDataFrame):
            return self._partial_split_pandas(interactions)
        if isinstance(interactions, PolarsDataFrame):
            return self._partial_split_polars(interactions)
        if isinstance(interactions, SparkDataFrame):
            return self._partial_split_spark(interactions)
        msg = f"{self} is not implemented for {type(interactions)}"
        raise NotImplementedError(msg)

    def _core_split(self, interactions: DataFrameLike) -> tuple[DataFrameLike, DataFrameLike]:
        """
        Split interactions into train and test parts.

        :param interactions: Interactions to split.
        :return: ``(train, test)`` as produced by ``_partial_split``.
        """
        return self._partial_split(interactions)
replay/utils/common.py CHANGED
@@ -2,9 +2,10 @@ import functools
2
2
  import inspect
3
3
  import json
4
4
  from pathlib import Path
5
- from typing import Any, Callable, Union
5
+ from typing import Any, Callable, TypeVar, Union
6
6
 
7
7
  from polars import from_pandas as pl_from_pandas
8
+ from typing_extensions import ParamSpec
8
9
 
9
10
  from replay.data.dataset import Dataset
10
11
  from replay.preprocessing import (
@@ -16,6 +17,7 @@ from replay.splitters import (
16
17
  KFolds,
17
18
  LastNSplitter,
18
19
  NewUsersSplitter,
20
+ RandomNextNSplitter,
19
21
  RandomSplitter,
20
22
  RatioSplitter,
21
23
  TimeSplitter,
@@ -37,6 +39,7 @@ SavableObject = Union[
37
39
  KFolds,
38
40
  LastNSplitter,
39
41
  NewUsersSplitter,
42
+ RandomNextNSplitter,
40
43
  RandomSplitter,
41
44
  RatioSplitter,
42
45
  TimeSplitter,
@@ -56,6 +59,9 @@ if TORCH_AVAILABLE:
56
59
  PolarsSequentialDataset,
57
60
  ]
58
61
 
62
+ P = ParamSpec("P")
63
+ R = TypeVar("R")
64
+
59
65
 
60
66
  def save_to_replay(obj: SavableObject, path: Union[str, Path]) -> None:
61
67
  """
@@ -87,10 +93,10 @@ def _check_if_dataframe(var: Any):
87
93
  raise ValueError(msg)
88
94
 
89
95
 
90
- def check_if_dataframe(*args_to_check: str) -> Callable[..., Any]:
91
- def decorator_func(func: Callable[..., Any]) -> Callable[..., Any]:
96
+ def check_if_dataframe(*args_to_check: str) -> Callable[P, R]:
97
+ def decorator_func(func: Callable[P, R]) -> Callable[P, R]:
92
98
  @functools.wraps(func)
93
- def wrap_func(*args: Any, **kwargs: Any) -> Any:
99
+ def wrap_func(*args: P.args, **kwargs: P.kwargs) -> R:
94
100
  extended_kwargs = {}
95
101
  extended_kwargs.update(kwargs)
96
102
  extended_kwargs.update(dict(zip(inspect.signature(func).parameters.keys(), args)))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: replay-rec
3
- Version: 0.20.3rc0
3
+ Version: 0.21.0
4
4
  Summary: RecSys Library
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE
@@ -14,23 +14,29 @@ Classifier: Intended Audience :: Developers
14
14
  Classifier: Intended Audience :: Science/Research
15
15
  Classifier: Natural Language :: English
16
16
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
- Requires-Dist: d3rlpy (>=2.8.1,<2.9)
18
- Requires-Dist: implicit (>=0.7.2,<0.8)
19
- Requires-Dist: lightautoml (>=0.4.1,<0.5)
20
- Requires-Dist: lightning (>=2.0.2,<=2.4.0)
21
- Requires-Dist: numba (>=0.50,<1)
17
+ Provides-Extra: spark
18
+ Provides-Extra: torch
19
+ Provides-Extra: torch-cpu
20
+ Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
21
+ Requires-Dist: lightning ; extra == "torch"
22
+ Requires-Dist: lightning ; extra == "torch-cpu"
22
23
  Requires-Dist: numpy (>=1.20.0,<2)
23
24
  Requires-Dist: pandas (>=1.3.5,<2.4.0)
24
25
  Requires-Dist: polars (<2.0)
25
- Requires-Dist: psutil (<=7.0.0)
26
+ Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
27
+ Requires-Dist: psutil ; extra == "spark"
26
28
  Requires-Dist: pyarrow (<22.0)
27
- Requires-Dist: pyspark (>=3.0,<3.5)
28
- Requires-Dist: pytorch-optimizer (>=3.8.0,<4)
29
- Requires-Dist: sb-obp (>=0.5.10,<0.6)
29
+ Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
30
+ Requires-Dist: pyspark ; extra == "spark"
31
+ Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
32
+ Requires-Dist: pytorch-optimizer ; extra == "torch"
33
+ Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
30
34
  Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
31
35
  Requires-Dist: scipy (>=1.8.1,<2.0.0)
32
36
  Requires-Dist: setuptools
33
- Requires-Dist: torch (>=1.8,<2.9.0)
37
+ Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
38
+ Requires-Dist: torch ; extra == "torch"
39
+ Requires-Dist: torch ; extra == "torch-cpu"
34
40
  Requires-Dist: tqdm (>=4.67,<5)
35
41
  Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
36
42
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
@@ -231,7 +237,7 @@ pip install optuna
231
237
 
232
238
  2) Model compilation via OpenVINO:
233
239
  ```bash
234
- pip install openvino onnx
240
+ pip install openvino onnx onnxscript
235
241
  ```
236
242
 
237
243
  3) Vector database and hierarchical search support:
@@ -0,0 +1,223 @@
1
+ replay/__init__.py,sha256=v3mrDhnKFg0X1ZQBAAyAMlOgyZDPiRd01VsfpkOu9bo,225
2
+ replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
3
+ replay/data/dataset.py,sha256=yBl-yJVIokgN4prFY949tHe2UVJC_j5xdaulIoSPvQI,31252
4
+ replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
5
+ replay/data/dataset_utils/dataset_label_encoder.py,sha256=bxuJPhShFZBok7bQZYGNMV1etCLNTJUpyKO5MIwWack,9823
6
+ replay/data/nn/__init__.py,sha256=Dpso6tN10moj92_NrXCVWBEAMhnGXewGC12H9fTCg0E,1228
7
+ replay/data/nn/parquet/__init__.py,sha256=e6FDBPzlv9SMduGJOtn2EarxPXk3_wHKWConS__SmWk,786
8
+ replay/data/nn/parquet/collate.py,sha256=tOArGUnJCILdAHEHELW7o3iuKCVD4w8BEbxNYXv7yJc,984
9
+ replay/data/nn/parquet/constants/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ replay/data/nn/parquet/constants/batches.py,sha256=2VJwk_W9wOk6C1P4IMvlO5bWuP1i14TCFGHaL487TkI,271
11
+ replay/data/nn/parquet/constants/device.py,sha256=EV25_HKMiPyAx7pops1Vr3YVR-9CmW_cpmxlymmyg9Q,51
12
+ replay/data/nn/parquet/constants/filesystem.py,sha256=v23OWKtTDFnCqCYuhL4d9o-PDjihfYIMVkatWzhqoiQ,67
13
+ replay/data/nn/parquet/constants/metadata.py,sha256=UQdTtnMPwGkpwgogrIij5C9G_HdKCZXKoGur0KFIdCM,133
14
+ replay/data/nn/parquet/fixed_batch_dataset.py,sha256=SFfyUkFaleZ4W_oskrl_6ws8f10Dkqo71U7C6g7yuD8,5150
15
+ replay/data/nn/parquet/impl/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ replay/data/nn/parquet/impl/array_1d_column.py,sha256=HgjJ4Dz1za_X_WfqrrXbjV_pPo5B_P3OJ3NXw8cjcBY,4823
17
+ replay/data/nn/parquet/impl/array_2d_column.py,sha256=2pjjmlF8Kbqi11uTtMYa29mWR1Aiieajvx3w96YrW50,5633
18
+ replay/data/nn/parquet/impl/column_protocol.py,sha256=Tjcbo3b2I834OMsaVz312AzvFMAiSydWDtN7_uMNzc0,340
19
+ replay/data/nn/parquet/impl/indexing.py,sha256=_ETICbsn-q70iEvUAIwgoZFqrIB2UxQVvqQ4kD3F8DY,4945
20
+ replay/data/nn/parquet/impl/masking.py,sha256=NBq6klPCAUO-Zm-VvCf1t6E_yLGbo03KvAq7Bl64ZsI,627
21
+ replay/data/nn/parquet/impl/named_columns.py,sha256=LUlI7tsh-6kfcVAAKjpPnZvGUwaCvHOj-Zqgqh73A14,3117
22
+ replay/data/nn/parquet/impl/numeric_column.py,sha256=A1jKct3YJegzzu8BoHHcypAgZUPFM3QR03FkdMQfxnI,3940
23
+ replay/data/nn/parquet/impl/utils.py,sha256=MqZcSC4fQnRrrxHFg21ukrntTQkASpEh6SKftOP20Ds,446
24
+ replay/data/nn/parquet/info/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ replay/data/nn/parquet/info/distributed_info.py,sha256=po7pl0m24pqjzlWKlfrdDWKZ4w0FEPoYLfW0KqqFCVY,850
26
+ replay/data/nn/parquet/info/partitioning.py,sha256=2XuiRlHaQbiRHrmPrmn2JNogiHvknmvIHB8IIK3pf78,4525
27
+ replay/data/nn/parquet/info/replicas.py,sha256=L2YnD6nvp_hjqhtptzq5KLdm8WlUskFnxbvL_06_AYQ,2480
28
+ replay/data/nn/parquet/info/worker_info.py,sha256=sIqBqHSeFdO00dDg_Mc_6UNXDQXGcu0iQVinRy84RUE,947
29
+ replay/data/nn/parquet/iterable_dataset.py,sha256=mQe2xvrpOU3vrVdy_tCxaFk45fqJv8mnQQS48-sQcqU,4246
30
+ replay/data/nn/parquet/iterator.py,sha256=X5KXtjdY_uSfMlP9IXBqMzSimBqlAZbYX_Y483q_3U8,2577
31
+ replay/data/nn/parquet/metadata/__init__.py,sha256=UZX60ANtjo6zX0p43hU9q8fBldVJNCEmGzXjHqz0MJQ,341
32
+ replay/data/nn/parquet/metadata/metadata.py,sha256=jJOL8mieXhX18FO9lgaP95MOtO1l7tY63ldxoOAvzwA,3459
33
+ replay/data/nn/parquet/parquet_dataset.py,sha256=pKthRppp0MstfNwOk9wMrE6wFvDecCtbTKWIri4HGr0,8017
34
+ replay/data/nn/parquet/parquet_module.py,sha256=g53lgb-bydDg5P27I4MODnnMcRi1qjpvAw3_QQ9UgxQ,8208
35
+ replay/data/nn/parquet/partitioned_iterable_dataset.py,sha256=BZEh2EiBKMZxi822-doyTbjDkZQQ62SxAp_NhZVZdmk,1938
36
+ replay/data/nn/parquet/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
+ replay/data/nn/parquet/utils/compute_length.py,sha256=VWabulpRICy-_Z0ZBXpEmhAIlpXVwTwe9kX2L2XCdbE,2492
38
+ replay/data/nn/schema.py,sha256=vLSDj4ZprOL9aurdcpOZ78KgNRXXuwt4wuTq5feiAvA,17115
39
+ replay/data/nn/sequence_tokenizer.py,sha256=zh026PRsTzPhUhW1SqPOvAZOdrIDbDyBRwdkgwtvTh0,37745
40
+ replay/data/nn/sequential_dataset.py,sha256=BcLkM_w3yF7F0EgPK5_jcreurh8k0fVJBoA9KJpp1fM,11800
41
+ replay/data/nn/torch_sequential_dataset.py,sha256=VQ3l3SQBFxIuXKr5FpVJNE-As3MgJ7SAa4Aeb0S2yNA,11874
42
+ replay/data/nn/utils.py,sha256=Ic3G4yZRIzBYXLmwP1VstlZXPNR7AYGCc5EyZAERp5c,3297
43
+ replay/data/schema.py,sha256=JmYLCrNgBS5oq4O_PT724Gr1pDurHEykcqV8Xaj0XTw,15922
44
+ replay/data/spark_schema.py,sha256=4o0Kn_fjwz2-9dBY3q46F9PL0F3E7jdVpIlX7SG3OZI,1111
45
+ replay/data/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
+ replay/data/utils/batching.py,sha256=jBNhRC5jqNe2pVVlmvFLvjTo86Ud0e_Lj2P0W2yNcKY,2268
47
+ replay/data/utils/typing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
+ replay/data/utils/typing/dtype.py,sha256=QJigLH7fv-xIb_s-R_70KTZxOgl2ZJkhEhf_txziRAY,1590
49
+ replay/metrics/__init__.py,sha256=j0PGvUehaPEZMNo9SQwJsnvzrS4bam9eHrRMQFLnMjY,2813
50
+ replay/metrics/base_metric.py,sha256=ejtwFHktN4J8Fi1HIM3w0zlMAd8nO7-XpFi2D1iHXUQ,16010
51
+ replay/metrics/categorical_diversity.py,sha256=3tp8n457Ob4gjM-UTB5N19u9WAF7fLDkWKk-Mth-Vzc,10769
52
+ replay/metrics/coverage.py,sha256=e6vPItrRlI-mLNuOT5uoo5lMAAzkYGKZRxvupi21dMk,8528
53
+ replay/metrics/descriptors.py,sha256=BHORyGKfJgPeUjgLO0u2urSTe16UQbb-HHh8soqnwDE,3893
54
+ replay/metrics/experiment.py,sha256=6Sw8PyItn3E2R-BBa_YwrmtBV3n0uAGHHOvkhHYgMz4,8125
55
+ replay/metrics/hitrate.py,sha256=LcOJLMs3_Dq4_pbKx95qdCdjGrX52dyWyuWUFXCyaDw,2314
56
+ replay/metrics/map.py,sha256=dIZcmUxd2XnNC7d_d7gmq0cjNaI1hlNMaJTSHGCokQE,2572
57
+ replay/metrics/mrr.py,sha256=qM8tVMSoyYR-kTx0mnBGppoC53SxNlZKm7JKMUmSv9U,2163
58
+ replay/metrics/ndcg.py,sha256=izajmD243ZIK3KLm9M-NtLwxb9N3Ktj58__AAfwF6Vc,3110
59
+ replay/metrics/novelty.py,sha256=j3p1fbUVi2QQgEre42jeQx73PYYDUhy5gYlrL4BL5b8,5488
60
+ replay/metrics/offline_metrics.py,sha256=f_U4Tk3Ke5sR0_OYvoE2_nD6wrOCveg3DM3B9pStVUI,20454
61
+ replay/metrics/precision.py,sha256=DRlsgY_b4bJCOSZjCA58N41REMiDt-dbagRSXxfXyvY,2256
62
+ replay/metrics/recall.py,sha256=fzpASDiH88zcpXJZTbStQ3nuzzSdhd9k1wjF27rM4wc,2447
63
+ replay/metrics/rocauc.py,sha256=1vaVEK7DQTL8BX-i7A64hTFWyO38aNycscPGrdWKwbA,3282
64
+ replay/metrics/surprisal.py,sha256=HkmYrOuw3jydxFrkidjdcpAcKz2DeOnMsKqwB2g9pwY,7526
65
+ replay/metrics/torch_metrics_builder.py,sha256=mnHrmRTOKZ_edrTrTKs7IPzKt5DkQYRd2B_8b3bB9yU,14071
66
+ replay/metrics/unexpectedness.py,sha256=LSi-z50l3_yrvLnmToHQzm6Ygf2QpNt_zhk6jdg7QUo,6882
67
+ replay/models/__init__.py,sha256=kECYluQZ83zRUWaHVvnt7Tg3BerHrJy9v8XfRxsqyYY,1123
68
+ replay/models/als.py,sha256=1MFAbcx64tv0MX1wE9CM1NxKD3F3ZDhZUrmt6dvHu74,6220
69
+ replay/models/association_rules.py,sha256=shBNsKjlii0YK-XA6bSl5Ov0ZFTnjxZbgKJU9PFYptY,14507
70
+ replay/models/base_neighbour_rec.py,sha256=SdGb2ejpYjHmxFNTk5zwEo0RWdfPAj1vKGP_oj7IrQo,7783
71
+ replay/models/base_rec.py,sha256=aNIEbSy8G5q92NOpDlSJbp0Z-lAkazFLa9eDAajl1wI,56067
72
+ replay/models/cat_pop_rec.py,sha256=ed1X1PDQY41hFJ1cO3Q5OWy0rXhV5_n23hJ-QHWONtE,11968
73
+ replay/models/cluster.py,sha256=9JcpGnbfgFa4UsyxPAa4WMuJFa3rsuAxiKoy-s_UfyE,4970
74
+ replay/models/common.py,sha256=rFmfwwzWCWED2HaDVuSN7ZUAgaNPGPawUudgn4IApbo,2121
75
+ replay/models/extensions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
76
+ replay/models/extensions/ann/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
+ replay/models/extensions/ann/ann_mixin.py,sha256=Ua1fuwrvtISNDQ8iPV-ln8S1LDKz8-rIU2UYsMExAiU,7782
78
+ replay/models/extensions/ann/entities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
79
+ replay/models/extensions/ann/entities/base_hnsw_param.py,sha256=5GRdcQj4-zhNXfJ7ko2WHGHgRuXCzSHCRcRxljl1V4c,776
80
+ replay/models/extensions/ann/entities/hnswlib_param.py,sha256=j3V4JXM_yfR6s2TjYakIXMg-zS1-MrP6an930DEIWGM,2104
81
+ replay/models/extensions/ann/entities/nmslib_hnsw_param.py,sha256=WeEhRR4jKqgvWK_zDK8fx6kEqc2e_bc0kubvqK3iV8c,2162
82
+ replay/models/extensions/ann/index_builders/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
+ replay/models/extensions/ann/index_builders/base_index_builder.py,sha256=Ul25G0FaNLOXUjrDXxZDTg7tLXlv1N6wR8kWjWICtZ0,2110
84
+ replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py,sha256=U8-3lRahyWmWkZ7tYuO-Avd1jX-lGh7JukC140wJ-WQ,1600
85
+ replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py,sha256=1NLWyAJGYgp46uUBhUYQyd0stmG6DhLh7U4JEne5TFw,1308
86
+ replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py,sha256=cf3LhBCRRN-lBYGlJbv8vnY-KVeHAleN5cVjvd58Ibs,2476
87
+ replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py,sha256=0DPJ3WAt0cZ5dmtZv87fmMEgYXWf8rM35f7CA_DgWZY,2618
88
+ replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py,sha256=AIkVnobesnTM5lrBSWf9gd0CySwFQ0vH_DjemfLS4Cs,1925
89
+ replay/models/extensions/ann/index_inferers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
90
+ replay/models/extensions/ann/index_inferers/base_inferer.py,sha256=I39aqEc2somfndrCd-KC3XYZnYSrJ2hGpR9y6wO93NA,2524
91
+ replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py,sha256=JjT4l_XAjzUOsTAE7OS88zAgPd_h_O44oUnn2kVr8E0,2477
92
+ replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py,sha256=CoY_oMfdcwnh87ceuSpHXu4Czle9xxeMisO8XJUuJLE,1717
93
+ replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py,sha256=tjuqbkztWBU4K6qp5LPFU_GOGJf2f4oXneExtUEVUzw,3128
94
+ replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py,sha256=S5eCBZlTXxEAeX6yeZGC7j56gOcJ7lMNb4Cs_5PEj9E,2203
95
+ replay/models/extensions/ann/index_inferers/utils.py,sha256=6IST2FPSY3nuYu5KqzRpd4FgdaV3GnQRQlxp9LN_yyA,641
96
+ replay/models/extensions/ann/index_stores/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
97
+ replay/models/extensions/ann/index_stores/base_index_store.py,sha256=u4l2ybAXX92ZMGK7NqqosbKF75QgFqhAMVadd5ePj6Y,910
98
+ replay/models/extensions/ann/index_stores/hdfs_index_store.py,sha256=0zDq9PdDOiD6HvtZlfjTbuJHfeTOWBTQ_HiuqZmoxtA,3090
99
+ replay/models/extensions/ann/index_stores/shared_disk_index_store.py,sha256=thl4T1uYU4Gtk4nBao_qK8CbFBdX1xmXNishxgfCd-I,2030
100
+ replay/models/extensions/ann/index_stores/spark_files_index_store.py,sha256=QP_8mE7EIBbePIe0AB-IWuJLRA5MR3wswCEt8oHzn-0,3617
101
+ replay/models/extensions/ann/index_stores/utils.py,sha256=6r2GP_EFCaCguolW857pb4lRS8rh6_Nv_Edso9_j5no,3756
102
+ replay/models/extensions/ann/utils.py,sha256=AgQvThi_DvEtakQeTno9hVZVWiWMFHKTjRcQ2wLa5vk,1222
103
+ replay/models/kl_ucb.py,sha256=L6vC2KsTBTTx4ckmGhWybOiLa5Wt54N7cgl7jS2FQRg,6731
104
+ replay/models/knn.py,sha256=HEiGHHQg9pV1_EIWZHfK-XD0BNAm1bj1c0ND9rYnj3k,8992
105
+ replay/models/lin_ucb.py,sha256=iAR3PbbaQKqmisOKEx9ZyfpxnxcZomr6YauG4mvSakU,18800
106
+ replay/models/nn/__init__.py,sha256=AT3o1qXaxUq4_QIGlcGuSs54ZpueOo-SbpZwuGI-6os,41
107
+ replay/models/nn/loss/__init__.py,sha256=s3iO9QTZvLz_ony2b5K0hEmDmitrXQnAe9j6BRxLpR4,53
108
+ replay/models/nn/loss/sce.py,sha256=mRJZYmwQNT-kMi66pXrE1-OdM7y_YEQFHzC37odnEo4,5165
109
+ replay/models/nn/optimizer_utils/__init__.py,sha256=9aiEk662v9-qJgzi8TZYaaqQSiZWr4ZleFHwcLOZX14,219
110
+ replay/models/nn/optimizer_utils/optimizer_factory.py,sha256=OJPX4XD_uG0iZKzxmhzT60uS66swGtpZnAV5A98vcgs,3439
111
+ replay/models/nn/sequential/__init__.py,sha256=CI2n0cxs_amqJrwBMq6n0Z_uBOu7CGXfagqvE4Jlmjw,128
112
+ replay/models/nn/sequential/bert4rec/__init__.py,sha256=JfZqHOGxcvOkICl5cWmZbZhaKXpkIvua-Wj57VWWEhw,399
113
+ replay/models/nn/sequential/bert4rec/dataset.py,sha256=Y63LESZYqKDG3OSvrWRy4Tgkib24VKJ9c9qJsGarr5k,12110
114
+ replay/models/nn/sequential/bert4rec/lightning.py,sha256=vxAf1H1VfLqgZhOz9fxEMmw4L3wfOr_wFnWHn_vPE34,28351
115
+ replay/models/nn/sequential/bert4rec/model.py,sha256=C1AKcQ8KF0XMXERwrFneW9kg7hzc-9FIqhCc-t91F7o,17469
116
+ replay/models/nn/sequential/callbacks/__init__.py,sha256=Q7mSZ_RB6iyD7QZaBL_NJ0uh8cRfgxq7gtPHbkSyhoo,282
117
+ replay/models/nn/sequential/callbacks/prediction_callbacks.py,sha256=UtEzO9_f5Jwku9dbz7twr4o2_cV3L-viC4lQuce5l1c,10808
118
+ replay/models/nn/sequential/callbacks/validation_callback.py,sha256=ydcNkUhaFD78ogqZWySzzKg4BaPyEkaRqmLiD4qFDzM,6583
119
+ replay/models/nn/sequential/compiled/__init__.py,sha256=eSVcCaUH5cDJQRbC7K99X7uMNR-Z-KR4TmYOGKWWJCI,531
120
+ replay/models/nn/sequential/compiled/base_compiled_model.py,sha256=f4AuTyx5tufQOtOWUSEgj1cWvMZzSL7YN2Z-PtURgTY,10478
121
+ replay/models/nn/sequential/compiled/bert4rec_compiled.py,sha256=woGI3qk4J2Rb5FyaDwpSCuG-AMfyH34F6Bt5pV-wqk0,6798
122
+ replay/models/nn/sequential/compiled/sasrec_compiled.py,sha256=eCRpxPdu94KyvczYJx2jgt6xaTZ3RpNYvcfyyyTYuiA,6170
123
+ replay/models/nn/sequential/postprocessors/__init__.py,sha256=89LGzkNHukcuC2-rfpiz7vmv1zyk6MNY-8zaXrvtn0M,164
124
+ replay/models/nn/sequential/postprocessors/_base.py,sha256=Pi8vWcaiqj3XddaxbiOYd5ME7ZfIkk0GPoCgpFKdO0g,1300
125
+ replay/models/nn/sequential/postprocessors/postprocessors.py,sha256=fclLmGkJbWAnNBw-Rvc_kKQsw0rUF2jfJ7s6VF8ge4I,8021
126
+ replay/models/nn/sequential/sasrec/__init__.py,sha256=c6130lRpPkcbuGgkM7slagBIgH7Uk5zUtSzFDEwAsik,250
127
+ replay/models/nn/sequential/sasrec/dataset.py,sha256=Le_rG-MoCpWoSKdrEJOyLo3S617FLMEMI8Ix51YEQx0,9452
128
+ replay/models/nn/sequential/sasrec/lightning.py,sha256=nJthkJvgp-nBy6mtt_5PvzUFihArXYTnZdAih85F01U,27067
129
+ replay/models/nn/sequential/sasrec/model.py,sha256=xLPz2HvPkDGMaXiWcyug7auQgBG-ai37OPFb7_jvorU,27876
130
+ replay/models/optimization/__init__.py,sha256=N8xCuzu0jQGwHrIBjuTRf-ZcZuBJ6FB0d9C5a7izJQU,338
131
+ replay/models/optimization/optuna_mixin.py,sha256=pKu-Vw9l2LsDycubpdJiLkC1eE4pKrDG0T2lhUgRUB4,11960
132
+ replay/models/optimization/optuna_objective.py,sha256=UHWOJwBngPA3IRz9yAMEWPg00oyb7Wq9PXuRPYHIiLE,7538
133
+ replay/models/pop_rec.py,sha256=Ju9y2rU2vW_jFU9-W15fbbr5_ZzYGihSjSxsqKsAf0Q,4964
134
+ replay/models/query_pop_rec.py,sha256=UNsHtf3eQpJom73ZmEO5us4guI4SnCLJYTfuUpRgqes,4086
135
+ replay/models/random_rec.py,sha256=9SC012_X3sNzrAjDG1CPGhjisZb6gnv4VCW7yIMSNpk,8066
136
+ replay/models/slim.py,sha256=OAdTS64bObZujzHkq8vfP1kkoLMSWxk1KLg6lCCA0N8,4551
137
+ replay/models/thompson_sampling.py,sha256=gcjlVl1mPiEVt70y8frA762O-eCZzd3SVg1lnDRCEHk,1939
138
+ replay/models/ucb.py,sha256=b2qFgvOAZcyv5triPk18duqF_jt-ty7mypenjRLNWwQ,6952
139
+ replay/models/wilson.py,sha256=o7aUWjq3648dAfgGBoWD5Gu-HzdyobPMaH2lzCLijiA,4558
140
+ replay/models/word2vec.py,sha256=atfj6GjR_L-TdurRFr1yi7B3BicJ3ZdFxixW9RfojJs,8882
141
+ replay/nn/__init__.py,sha256=Bd_Xi9s5g1zWSjMwk50ztG9oezhs37r2L4-mfB-gEsg,256
142
+ replay/nn/agg.py,sha256=JneTgVlo00cEg5FxzIp6NvNVOXqvL45e9vsXPP_5ztg,3799
143
+ replay/nn/attention.py,sha256=RR_KsqvnrZ1ZYr51KTBA9q5gB-0sqhmakjH1JdIo9dE,7812
144
+ replay/nn/embedding.py,sha256=xY_zPpC055cTXAZ8TShUYP3ZrBUA2HQwn4dkOCKXYJ0,11876
145
+ replay/nn/ffn.py,sha256=ivOFu14289URepyEFxYov_XNYMUrINjU-2rEqoXxbnU,4618
146
+ replay/nn/head.py,sha256=csjwQrcA7M7FebgSL1tKDbjfaoni52CymQR0Zt8RhWg,2084
147
+ replay/nn/lightning/__init__.py,sha256=jHiwtYuboGUY4Of18zrkvdWD0xXJ_zuo83-XgiqxSfY,36
148
+ replay/nn/lightning/callback/__init__.py,sha256=ImNEJeIK-wJnqdkZgP8tWTDQHaS9xYqzTEf3FEM0XAw,253
149
+ replay/nn/lightning/callback/metrics_callback.py,sha256=dIu1wDtqjXH6ogFGsh2L-dpkgz7OKjtTrVbBLrI4pjg,6986
150
+ replay/nn/lightning/callback/predictions_callback.py,sha256=e9PeXNyyGz-m46FEaafgCToPEVC9T5Cb8Q4sFArnpLY,11347
151
+ replay/nn/lightning/module.py,sha256=jFvevwiriY9alZMBw6KAiRMsJv-dJ8fEVrenVRiuWeI,5246
152
+ replay/nn/lightning/optimizer.py,sha256=1tXhz9RIBHLpEQtZ1PUzCAc4mn6Q_E38zR0nf5km6U8,1778
153
+ replay/nn/lightning/postprocessor/__init__.py,sha256=LhUeOWDD5vRBDXF2tQEjvPKH1rNIlrf5KPbcV66AdtQ,77
154
+ replay/nn/lightning/postprocessor/_base.py,sha256=X0LtYItmxlXt4Sxk3cOdyIK3FG5dijQzyh5Kv6s5FjE,1592
155
+ replay/nn/lightning/postprocessor/seen_items.py,sha256=h-sfD3vmNCdS7lYvqCfqw9oPqutmaSIuZ0CIidG0Y30,2922
156
+ replay/nn/lightning/scheduler.py,sha256=CUuynPTFrKBrkpmbWR-xpfAkHZ0Vfz_THUDo3uoZi8k,2714
157
+ replay/nn/loss/__init__.py,sha256=YXAXQIN0coj8MxeK5isTGXgvMxhH5pUO6j1D3d7jl3A,471
158
+ replay/nn/loss/base.py,sha256=oD1vATWoQDi45zG9EPjg3hgDrfpr4ue_rQFfArn1dFs,8871
159
+ replay/nn/loss/bce.py,sha256=cPlxdJTBZ0b22K6V9ve4qo7xkp99CjEsnl3_vVGphqs,8373
160
+ replay/nn/loss/ce.py,sha256=jOmhLtKD_E0jX8tUfXpsmaaQVHKKiwXW9USB_GyN3ZU,13218
161
+ replay/nn/loss/login_ce.py,sha256=NER_Hbs_H3IXn_bkgwG25VQNQ6ZjjDcxq-aMI7pC2eM,16498
162
+ replay/nn/loss/logout_ce.py,sha256=KhcYyCnUzLZR1sFpxM6_QliLoxmC6MJoLkPOgf_ZYzU,10306
163
+ replay/nn/mask.py,sha256=Jbx7sulGZYfasNaD9CZzJma0cEVaDlxdpzs295507II,3329
164
+ replay/nn/normalization.py,sha256=Z86t5WCr4KfVR9qCCe-EIAwwomnIIxb11PP88WHA1JI,187
165
+ replay/nn/output.py,sha256=6uecMOMN4FGoQ-NzKGacZnlrk_9TwQswpC-x3G_DMTY,1291
166
+ replay/nn/sequential/__init__.py,sha256=jet_ueMz5Bm087JDph7ln87NID7DbCb0WENj-tjoOGg,229
167
+ replay/nn/sequential/sasrec/__init__.py,sha256=8crj-JL8xeP-cCOCnxCSVF_-R6feKhj0YRHOcaMsqrU,213
168
+ replay/nn/sequential/sasrec/agg.py,sha256=e-IkIO-MMbei2UGxTUopWvloguJoVaZiN31sXkdUVag,2004
169
+ replay/nn/sequential/sasrec/diff_transformer.py,sha256=4ehM5EMizajmWBAzmcj3CYSFl21V1R2b7RDRJlx3O4Q,4790
170
+ replay/nn/sequential/sasrec/model.py,sha256=sQ2FvfDyZ3G6PjbNME--fMboqUt66z9J8t8YYlJ9J6Q,14803
171
+ replay/nn/sequential/sasrec/transformer.py,sha256=sJf__IPnhbJWDPuFTPSbBGSSntznQtS-hJtJo3iFBkw,4037
172
+ replay/nn/sequential/twotower/__init__.py,sha256=-rEASPqKCbS55MTTgeDZ5irfWfM9or1vNTHZnJN2AcU,124
173
+ replay/nn/sequential/twotower/model.py,sha256=VxUUjldHndCkDjrXGqmxGnTi5fh8vmnr7XNBpYjsqW8,28659
174
+ replay/nn/sequential/twotower/reader.py,sha256=j4mlKx5Lf3hFnSgaxMLkuqWLZd3dkLchDI4JEuZHLGY,3674
175
+ replay/nn/transform/__init__.py,sha256=9PeaDHmftb0s1gEEgJRNWw6Bl2wfE_-lImatipaHUQ0,705
176
+ replay/nn/transform/copy.py,sha256=ZfNXbMJYTwXDMJ5T8ib9Dh5XOGLjj7gGB4NbBExFZiM,1302
177
+ replay/nn/transform/grouping.py,sha256=XOJoVBk234DI6x05Kqr7KOjLetDaLp2NMAJWHecQcsI,1384
178
+ replay/nn/transform/negative_sampling.py,sha256=R5di5-IuNtpbjcjHYcBTZYd6Lk2R5_I77PVioaL5s5w,7475
179
+ replay/nn/transform/next_token.py,sha256=UONG8_J-UxZdRCOEcz7fvU40k-hvE_h7ff014L9Ukpg,4491
180
+ replay/nn/transform/rename.py,sha256=_uD2e1UmtBRyOTVpHUnZ5xhePmClaGQsc0g7Es-rupE,1026
181
+ replay/nn/transform/reshape.py,sha256=sgswIogWHUwOVp02k13Qopn84LofqLoA4M7U1GAfmio,1359
182
+ replay/nn/transform/sequence_roll.py,sha256=7jf42SgWHU1L7SirqQWXx0h9a6VQQ29kehE4LmdUt9o,1531
183
+ replay/nn/transform/template/__init__.py,sha256=lYzAekZUXwncGR66Nq8YypplGOtL00GFfm0PalGiY5g,106
184
+ replay/nn/transform/template/sasrec.py,sha256=FoOhroe-S0JPaxIQ3Ba-3_gyslgj47RoLL2geOxNAO4,1906
185
+ replay/nn/transform/template/twotower.py,sha256=BIlbqTfKEMcyx2Ksr4qzAD0h0mdhiTLa1xcmZ2e8Ksc,896
186
+ replay/nn/transform/token_mask.py,sha256=WcalZkY2UCoNiq2mBtu8fqYFOUfqCh21XyDMgvIpeB4,2529
187
+ replay/nn/transform/trim.py,sha256=mPn6LPxu3c3yE14heMSRsDEU4h94tkFiRr62mOa3lKg,1608
188
+ replay/nn/utils.py,sha256=GumtN-QRP9ljXYti3YvuNk13e0Q92xvkYuCJBhaViCI,801
189
+ replay/preprocessing/__init__.py,sha256=c6wFPAc6lATyp0lE-ZDjHMsXyEMPKX7Usuqylv6H5XQ,597
190
+ replay/preprocessing/converter.py,sha256=JQ-4u5x0eXtswl1iH-bZITBXQov1nebnZ6XcvpD8Twk,4417
191
+ replay/preprocessing/discretizer.py,sha256=jzYqvoSVmiL-oS-ri9Om0vSDoU8bCQimjUoe7FiPfLU,27024
192
+ replay/preprocessing/filters.py,sha256=cCX8BikKNqcAGFpJkYssQkR_6tUjjktSlpZOxK1ezUw,49930
193
+ replay/preprocessing/history_based_fp.py,sha256=oEu1CkCz7xcGbPdSTHfhTe1NimnFo50Arn8qngRBgE8,18702
194
+ replay/preprocessing/label_encoder.py,sha256=puedlFGitjI_yi4uxRIR6L4Wz6oZ93gIEPeylC-jCtI,41459
195
+ replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
196
+ replay/preprocessing/utils.py,sha256=e-JRoadbeTe3Qvp_NXMZNQkmgedeR6iJLyO_82xKPd0,7109
197
+ replay/scenarios/__init__.py,sha256=XXAKEQPTLlve-0O6NPwFgahFrb4oGcIq3HaYaaGxG2E,94
198
+ replay/scenarios/fallback.py,sha256=dO3s9jqYup4rbgMaY6Z6HGm1r7SXkm7jOvNZDr5zm_U,7138
199
+ replay/splitters/__init__.py,sha256=9vhrZ8nCgq_NYJkv4wn0JYqhKURZH6Z8IyRNN1BX6AI,510
200
+ replay/splitters/base_splitter.py,sha256=zvYVEHBYrK8Y2qPv3kYachfLFwR9-kUAiU1UJSNGS8A,7749
201
+ replay/splitters/cold_user_random_splitter.py,sha256=32VgAHiwk9Emkofu1KqwGZrrFiyrYtSQ3YPdt5p_XoQ,4423
202
+ replay/splitters/k_folds.py,sha256=RDDL3gE6M5qfK5Ig-pxxJeq3O4uxsWJjLFQRRzQ2Ssg,6211
203
+ replay/splitters/last_n_splitter.py,sha256=hMWIGYFg17LioT08VBXut5Ic-w9oXsKd739cy2xuwYs,15368
204
+ replay/splitters/new_users_splitter.py,sha256=NksAdl_wL9zwHj3cY5NqrrnkOajgyUDloSsRZ9HUE48,9160
205
+ replay/splitters/random_next_n_splitter.py,sha256=aRqRe1jll7o5Hj-si-jyr341T4nXLfpX39crwVpLl-Y,8713
206
+ replay/splitters/random_splitter.py,sha256=0DO0qulT0jp_GXswmFh3BMJ7utS-z9e-r5jIrmTKGC4,2989
207
+ replay/splitters/ratio_splitter.py,sha256=rFWN-nKBYx1qKrmtYzjYf08DWFiKOCo5ZRUz-NHJFfs,17506
208
+ replay/splitters/time_splitter.py,sha256=0ZAMK26b--1wjrfzCuNVBh7gMPTa8SGf4LMEgACiUxA,9013
209
+ replay/splitters/two_stage_splitter.py,sha256=8Zn6BTJmZg04CD4l2jmil2dEu6xtglJaSS5mkotIXRc,17823
210
+ replay/utils/__init__.py,sha256=3Skc9bUISEPPMMxdUCCT_S1q-i7cAT3KT0nExe-VMrw,343
211
+ replay/utils/common.py,sha256=_sBKR1hlZavXll8NN0hyGDIEdLakccPofu8JskHpBgk,5488
212
+ replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
213
+ replay/utils/distributions.py,sha256=UuhaC9HI6HnUXW97fEd-TsyDk4JT8t7k1T_6l5FpOMs,1203
214
+ replay/utils/model_handler.py,sha256=6WRyd39B-UXTtKTHWD_ssYN1vMmkjd417bwKb50uqJY,5754
215
+ replay/utils/session_handler.py,sha256=fQo2wseow8yuzKnEXT-aYAXcQIgRbTTXp0v7g1VVi0w,5138
216
+ replay/utils/spark_utils.py,sha256=GbRp-MuUoO3Pc4chFvlmo9FskSlRLeNlC3Go5pEJ6Ok,27411
217
+ replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
218
+ replay/utils/types.py,sha256=rD9q9CqEXgF4yy512Hv2nXclvwcnfodOnhBZ1HSUI4c,1260
219
+ replay_rec-0.21.0.dist-info/METADATA,sha256=9KaxfPOyxMV7l4O3L3qy59ACnvB1-ZbhwynJGKKlXzw,13573
220
+ replay_rec-0.21.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
221
+ replay_rec-0.21.0.dist-info/licenses/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
222
+ replay_rec-0.21.0.dist-info/licenses/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
223
+ replay_rec-0.21.0.dist-info/RECORD,,
@@ -1,62 +0,0 @@
1
- """
2
- Most metrics require dataframe with recommendations
3
- and dataframe with ground truth values —
4
- which objects each user interacted with.
5
-
6
- - recommendations (Union[pandas.DataFrame, spark.DataFrame]):
7
- predictions of a recommender system,
8
- DataFrame with columns ``[user_id, item_id, relevance]``
9
- - ground_truth (Union[pandas.DataFrame, spark.DataFrame]):
10
- test data, DataFrame with columns
11
- ``[user_id, item_id, timestamp, relevance]``
12
-
13
- Metric is calculated for all users, presented in ``ground_truth``
14
- for accurate metric calculation in case when the recommender system generated
15
- recommendation not for all users. It is assumed, that all users,
16
- we want to calculate metric for, have positive interactions.
17
-
18
- But if we have users, who observed the recommendations, but have not responded,
19
- those users will be ignored and metric will be overestimated.
20
- For such case we propose additional optional parameter ``ground_truth_users``,
21
- the dataframe with all users, which should be considered during the metric calculation.
22
-
23
- - ground_truth_users (Optional[Union[pandas.DataFrame, spark.DataFrame]]):
24
- full list of users to calculate metric for, DataFrame with ``user_id`` column
25
-
26
- Every metric is calculated using top ``K`` items for each user.
27
- It is also possible to calculate metrics
28
- using multiple values for ``K`` simultaneously.
29
- In this case the result will be a dictionary and not a number.
30
-
31
- Make sure your recommendations do not contain user-item duplicates
32
- as duplicates could lead to the wrong calculation results.
33
-
34
- - k (Union[Iterable[int], int]):
35
- a single number or a list, specifying the
36
- truncation length for recommendation list for each user
37
-
38
- By default, metrics are averaged by users,
39
- but you can alternatively use method ``metric.median``.
40
- Also, you can get the lower bound
41
- of ``conf_interval`` for a given ``alpha``.
42
-
43
- Diversity metrics require extra parameters on initialization stage,
44
- but do not use ``ground_truth`` parameter.
45
-
46
- For each metric, a formula for its calculation is given, because this is
47
- important for the correct comparison of algorithms, as mentioned in our
48
- `article <https://arxiv.org/abs/2206.12858>`_.
49
- """
50
-
51
- from replay.experimental.metrics.base_metric import Metric, NCISMetric
52
- from replay.experimental.metrics.coverage import Coverage
53
- from replay.experimental.metrics.hitrate import HitRate
54
- from replay.experimental.metrics.map import MAP
55
- from replay.experimental.metrics.mrr import MRR
56
- from replay.experimental.metrics.ncis_precision import NCISPrecision
57
- from replay.experimental.metrics.ndcg import NDCG
58
- from replay.experimental.metrics.precision import Precision
59
- from replay.experimental.metrics.recall import Recall
60
- from replay.experimental.metrics.rocauc import RocAuc
61
- from replay.experimental.metrics.surprisal import Surprisal
62
- from replay.experimental.metrics.unexpectedness import Unexpectedness