replay-rec 0.16.0rc0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
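Only one file's diff is reproduced in full below: the deletion of replay/experimental/models/base_torch_rec.py (entry 128 in the list, -247 lines). To inspect any of the other changes, the same wheel-to-wheel comparison can be rebuilt locally. Here is a minimal sketch using only the Python standard library; the wheel filenames follow from the versions in the page header, and the pip commands in the comment are one conventional way to fetch the wheels:

    # Sketch: rebuild this wheel-to-wheel comparison locally.
    # First download both wheels without their dependencies, e.g.:
    #   pip download replay-rec==0.16.0rc0 --no-deps
    #   pip download replay-rec==0.17.0 --no-deps
    import difflib
    import zipfile

    OLD = "replay_rec-0.16.0rc0-py3-none-any.whl"
    NEW = "replay_rec-0.17.0-py3-none-any.whl"

    def read_wheel(path):
        """Map each file inside the wheel archive to its decoded text lines."""
        with zipfile.ZipFile(path) as zf:
            return {
                name: zf.read(name).decode("utf-8", errors="replace").splitlines()
                for name in zf.namelist()
            }

    old, new = read_wheel(OLD), read_wheel(NEW)
    for name in sorted(old.keys() | new.keys()):
        # Missing keys fall back to [], so added and deleted files show up
        # as all-plus or all-minus hunks, like the one reproduced below.
        for line in difflib.unified_diff(
            old.get(name, []), new.get(name, []),
            fromfile=f"0.16.0rc0/{name}", tofile=f"0.17.0/{name}", lineterm="",
        ):
            print(line)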
replay/experimental/models/base_torch_rec.py
@@ -1,247 +0,0 @@
- from abc import abstractmethod
- from typing import Any, Dict, Optional
-
- import numpy as np
- import torch
- from torch import nn
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from torch.optim.optimizer import Optimizer
- from torch.utils.data import DataLoader
-
- from replay.data import get_schema
- from replay.experimental.models.base_rec import Recommender
- from replay.experimental.utils.session_handler import State
- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-
- if PYSPARK_AVAILABLE:
-     from pyspark.sql import functions as sf
-
-
- class TorchRecommender(Recommender):
-     """Base class for neural recommenders"""
-
-     model: Any
-     device: torch.device
-
-     def __init__(self):
-         self.logger.info(
-             "The model is neural network with non-distributed training"
-         )
-         self.checkpoint_path = State().session.conf.get("spark.local.dir")
-         self.device = State().device
-
-     def _run_train_step(self, batch, optimizer):
-         self.model.train()
-         optimizer.zero_grad()
-         model_result = self._batch_pass(batch, self.model)
-         loss = self._loss(**model_result)
-         loss.backward()
-         optimizer.step()
-         return loss.item()
-
-     def _run_validation(
-         self, valid_data_loader: DataLoader, epoch: int
-     ) -> float:
-         self.model.eval()
-         valid_loss = 0
-         with torch.no_grad():
-             for batch in valid_data_loader:
-                 model_result = self._batch_pass(batch, self.model)
-                 valid_loss += self._loss(**model_result)
-         valid_loss /= len(valid_data_loader)
-         valid_debug_message = f"""Epoch[{epoch}] validation
-         average loss: {valid_loss:.5f}"""
-         self.logger.debug(valid_debug_message)
-         return valid_loss.item()
-
-     # pylint: disable=too-many-arguments
-     def train(
-         self,
-         train_data_loader: DataLoader,
-         valid_data_loader: DataLoader,
-         optimizer: Optimizer,
-         lr_scheduler: ReduceLROnPlateau,
-         epochs: int,
-         model_name: str,
-     ) -> None:
-         """
-         Run training loop
-
-         :param train_data_loader: data loader for training
-         :param valid_data_loader: data loader for validation
-         :param optimizer: optimizer
-         :param lr_scheduler: scheduler used to decrease learning rate
-         :param epochs: num training epochs
-         :param model_name: model name for checkpoint saving
-         :return:
-         """
-         best_valid_loss = np.inf
-         for epoch in range(epochs):
-             for batch in train_data_loader:
-                 train_loss = self._run_train_step(batch, optimizer)
-
-             train_debug_message = f"""Epoch[{epoch}] current loss:
-             {train_loss:.5f}"""
-             self.logger.debug(train_debug_message)
-
-             valid_loss = self._run_validation(valid_data_loader, epoch)
-             lr_scheduler.step(valid_loss)
-
-             if valid_loss < best_valid_loss:
-                 best_checkpoint = "/".join(
-                     [
-                         self.checkpoint_path,
-                         f"/best_{model_name}_{epoch+1}_loss={valid_loss}.pt",
-                     ]
-                 )
-                 self._save_model(best_checkpoint)
-                 best_valid_loss = valid_loss
-         self._load_model(best_checkpoint)
-
-     @abstractmethod
-     def _batch_pass(self, batch, model) -> Dict[str, Any]:
-         """
-         Apply model to a single batch.
-
-         :param batch: data batch
-         :param model: model object
-         :return: dictionary used to calculate loss.
-         """
-
-     @abstractmethod
-     def _loss(self, **kwargs) -> torch.Tensor:
-         """
-         Returns loss value
-
-         :param **kwargs: dictionary used to calculate loss
-         :return: 1x1 tensor
-         """
-
-     # pylint: disable=too-many-arguments
-     # pylint: disable=too-many-locals
-     def _predict(
-         self,
-         log: SparkDataFrame,
-         k: int,
-         users: SparkDataFrame,
-         items: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-         filter_seen_items: bool = True,
-     ) -> SparkDataFrame:
-         items_consider_in_pred = items.toPandas()["item_idx"].values
-         items_count = self._item_dim
-         model = self.model.cpu()
-         agg_fn = self._predict_by_user
-
-         def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-             return agg_fn(
-                 pandas_df, model, items_consider_in_pred, k, items_count
-             )[["user_idx", "item_idx", "relevance"]]
-
-         self.logger.debug("Predict started")
-         # do not apply map on cold users for MultVAE predict
-         join_type = "inner" if str(self) == "MultVAE" else "left"
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         recs = (
-             users.join(log, how=join_type, on="user_idx")
-             .select("user_idx", "item_idx")
-             .groupby("user_idx")
-             .applyInPandas(grouped_map, rec_schema)
-         )
-         return recs
-
-     def _predict_pairs(
-         self,
-         pairs: SparkDataFrame,
-         log: Optional[SparkDataFrame] = None,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ) -> SparkDataFrame:
-         items_count = self._item_dim
-         model = self.model.cpu()
-         agg_fn = self._predict_by_user_pairs
-         users = pairs.select("user_idx").distinct()
-
-         def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-             return agg_fn(pandas_df, model, items_count)[
-                 ["user_idx", "item_idx", "relevance"]
-             ]
-
-         self.logger.debug("Calculate relevance for user-item pairs")
-         user_history = (
-             users.join(log, how="inner", on="user_idx")
-             .groupBy("user_idx")
-             .agg(sf.collect_list("item_idx").alias("item_idx_history"))
-         )
-         user_pairs = pairs.groupBy("user_idx").agg(
-             sf.collect_list("item_idx").alias("item_idx_to_pred")
-         )
-         full_df = user_pairs.join(user_history, on="user_idx", how="inner")
-
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         recs = full_df.groupby("user_idx").applyInPandas(
-             grouped_map, rec_schema
-         )
-
-         return recs
-
-     @staticmethod
-     @abstractmethod
-     def _predict_by_user(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         items_np: np.ndarray,
-         k: int,
-         item_count: int,
-     ) -> PandasDataFrame:
-         """
-         Calculate predictions.
-
-         :param pandas_df: DataFrame with user-item interactions ``[user_idx, item_idx]``
-         :param model: trained model
-         :param items_np: items available for recommendations
-         :param k: length of recommendation list
-         :param item_count: total number of items
-         :return: DataFrame ``[user_idx, item_idx, relevance]``
-         """
-
-     @staticmethod
-     @abstractmethod
-     def _predict_by_user_pairs(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         item_count: int,
-     ) -> PandasDataFrame:
-         """
-         Get relevance for provided pairs
-
-         :param pandas_df: DataFrame with rated items and items that need prediction
-             ``[user_idx, item_idx_history, item_idx_to_pred]``
-         :param model: trained model
-         :param item_count: total number of items
-         :return: DataFrame ``[user_idx, item_idx, relevance]``
-         """
-
-     def _load_model(self, path: str) -> None:
-         """
-         Load model from file
-
-         :param path: path to model
-         :return:
-         """
-         self.logger.debug("-- Loading model from file")
-         self.model.load_state_dict(torch.load(path))
-
-     def _save_model(self, path: str) -> None:
-         torch.save(self.model.state_dict(), path)
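
This removal also takes the TorchRecommender extension contract out of the package: subclasses supplied _batch_pass and _loss for training plus the static _predict_by_user and _predict_by_user_pairs for inference, and inherited the Spark-side _predict/_predict_pairs plumbing and the checkpointing train loop shown above. A hypothetical minimal subclass, sketched against those signatures only; TinyRecommender, its logistic loss, and the score_user helper (assumed to return a NumPy array of scores) are invented for illustration, and the import resolves only on releases up to 0.16.0rc0:

    import numpy as np
    import pandas as pd
    import torch
    from torch import nn

    # Only present in replay-rec <= 0.16.0rc0; removed in 0.17.0.
    from replay.experimental.models.base_torch_rec import TorchRecommender

    class TinyRecommender(TorchRecommender):
        """Invented example of the four hooks the deleted base class required."""

        def _batch_pass(self, batch, model):
            users, items, labels = batch
            logits = model(users, items)
            # The returned dict is splatted directly into `_loss(**kwargs)`.
            return {"logits": logits, "labels": labels}

        def _loss(self, logits, labels):
            return nn.functional.binary_cross_entropy_with_logits(logits, labels)

        @staticmethod
        def _predict_by_user(pandas_df, model, items_np, k, item_count):
            # Called once per user group by `applyInPandas` in `_predict`.
            user = int(pandas_df["user_idx"].iloc[0])
            with torch.no_grad():
                scores = model.score_user(user, items_np)  # hypothetical helper
            top = np.argsort(scores)[::-1][:k]
            return pd.DataFrame(
                {"user_idx": user, "item_idx": items_np[top], "relevance": scores[top]}
            )

        @staticmethod
        def _predict_by_user_pairs(pandas_df, model, item_count):
            # `_predict_pairs` collects each user's candidate items into one row.
            user = int(pandas_df["user_idx"].iloc[0])
            items = np.asarray(pandas_df["item_idx_to_pred"].iloc[0])
            with torch.no_grad():
                scores = model.score_user(user, items)  # hypothetical helper
            return pd.DataFrame(
                {"user_idx": user, "item_idx": items, "relevance": scores}
            )

Projects that subclassed this base will need to vendor the class or move to the Lightning-based models under replay/models/nn/sequential/, which this release keeps.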