replay-rec: replay_rec-0.16.0rc0-py3-none-any.whl → replay_rec-0.17.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -61
- replay/experimental/metrics/base_metric.py +0 -661
- replay/experimental/metrics/coverage.py +0 -117
- replay/experimental/metrics/experiment.py +0 -200
- replay/experimental/metrics/hitrate.py +0 -27
- replay/experimental/metrics/map.py +0 -31
- replay/experimental/metrics/mrr.py +0 -19
- replay/experimental/metrics/ncis_precision.py +0 -32
- replay/experimental/metrics/ndcg.py +0 -50
- replay/experimental/metrics/precision.py +0 -23
- replay/experimental/metrics/recall.py +0 -26
- replay/experimental/metrics/rocauc.py +0 -50
- replay/experimental/metrics/surprisal.py +0 -102
- replay/experimental/metrics/unexpectedness.py +0 -74
- replay/experimental/models/__init__.py +0 -10
- replay/experimental/models/admm_slim.py +0 -216
- replay/experimental/models/base_neighbour_rec.py +0 -222
- replay/experimental/models/base_rec.py +0 -1361
- replay/experimental/models/base_torch_rec.py +0 -247
- replay/experimental/models/cql.py +0 -468
- replay/experimental/models/ddpg.py +0 -1007
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -193
- replay/experimental/models/dt4rec/gpt1.py +0 -411
- replay/experimental/models/dt4rec/trainer.py +0 -128
- replay/experimental/models/dt4rec/utils.py +0 -274
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
- replay/experimental/models/implicit_wrap.py +0 -138
- replay/experimental/models/lightfm_wrap.py +0 -327
- replay/experimental/models/mult_vae.py +0 -374
- replay/experimental/models/neuromf.py +0 -462
- replay/experimental/models/scala_als.py +0 -311
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -58
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -929
- replay/experimental/preprocessing/padder.py +0 -231
- replay/experimental/preprocessing/sequence_generator.py +0 -218
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
- replay/experimental/scenarios/two_stages/reranker.py +0 -116
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -213
- replay/experimental/utils/session_handler.py +0 -47
- replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
- replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
- {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/experimental/models/base_torch_rec.py (file removed)
@@ -1,247 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Dict, Optional
-
-import numpy as np
-import torch
-from torch import nn
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.optim.optimizer import Optimizer
-from torch.utils.data import DataLoader
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import Recommender
-from replay.experimental.utils.session_handler import State
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-
-if PYSPARK_AVAILABLE:
-    from pyspark.sql import functions as sf
-
-
-class TorchRecommender(Recommender):
-    """Base class for neural recommenders"""
-
-    model: Any
-    device: torch.device
-
-    def __init__(self):
-        self.logger.info(
-            "The model is neural network with non-distributed training"
-        )
-        self.checkpoint_path = State().session.conf.get("spark.local.dir")
-        self.device = State().device
-
-    def _run_train_step(self, batch, optimizer):
-        self.model.train()
-        optimizer.zero_grad()
-        model_result = self._batch_pass(batch, self.model)
-        loss = self._loss(**model_result)
-        loss.backward()
-        optimizer.step()
-        return loss.item()
-
-    def _run_validation(
-        self, valid_data_loader: DataLoader, epoch: int
-    ) -> float:
-        self.model.eval()
-        valid_loss = 0
-        with torch.no_grad():
-            for batch in valid_data_loader:
-                model_result = self._batch_pass(batch, self.model)
-                valid_loss += self._loss(**model_result)
-            valid_loss /= len(valid_data_loader)
-        valid_debug_message = f"""Epoch[{epoch}] validation
-        average loss: {valid_loss:.5f}"""
-        self.logger.debug(valid_debug_message)
-        return valid_loss.item()
-
-    # pylint: disable=too-many-arguments
-    def train(
-        self,
-        train_data_loader: DataLoader,
-        valid_data_loader: DataLoader,
-        optimizer: Optimizer,
-        lr_scheduler: ReduceLROnPlateau,
-        epochs: int,
-        model_name: str,
-    ) -> None:
-        """
-        Run training loop
-        :param train_data_loader: data loader for training
-        :param valid_data_loader: data loader for validation
-        :param optimizer: optimizer
-        :param lr_scheduler: scheduler used to decrease learning rate
-        :param lr_scheduler: scheduler used to decrease learning rate
-        :param epochs: num training epochs
-        :param model_name: model name for checkpoint saving
-        :return:
-        """
-        best_valid_loss = np.inf
-        for epoch in range(epochs):
-            for batch in train_data_loader:
-                train_loss = self._run_train_step(batch, optimizer)
-
-            train_debug_message = f"""Epoch[{epoch}] current loss:
-            {train_loss:.5f}"""
-            self.logger.debug(train_debug_message)
-
-            valid_loss = self._run_validation(valid_data_loader, epoch)
-            lr_scheduler.step(valid_loss)
-
-            if valid_loss < best_valid_loss:
-                best_checkpoint = "/".join(
-                    [
-                        self.checkpoint_path,
-                        f"/best_{model_name}_{epoch+1}_loss={valid_loss}.pt",
-                    ]
-                )
-                self._save_model(best_checkpoint)
-                best_valid_loss = valid_loss
-        self._load_model(best_checkpoint)
-
-    @abstractmethod
-    def _batch_pass(self, batch, model) -> Dict[str, Any]:
-        """
-        Apply model to a single batch.
-
-        :param batch: data batch
-        :param model: model object
-        :return: dictionary used to calculate loss.
-        """
-
-    @abstractmethod
-    def _loss(self, **kwargs) -> torch.Tensor:
-        """
-        Returns loss value
-
-        :param **kwargs: dictionary used to calculate loss
-        :return: 1x1 tensor
-        """
-
-    # pylint: disable=too-many-arguments
-    # pylint: disable=too-many-locals
-    def _predict(
-        self,
-        log: SparkDataFrame,
-        k: int,
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-        filter_seen_items: bool = True,
-    ) -> SparkDataFrame:
-        items_consider_in_pred = items.toPandas()["item_idx"].values
-        items_count = self._item_dim
-        model = self.model.cpu()
-        agg_fn = self._predict_by_user
-
-        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-            return agg_fn(
-                pandas_df, model, items_consider_in_pred, k, items_count
-            )[["user_idx", "item_idx", "relevance"]]
-
-        self.logger.debug("Predict started")
-        # do not apply map on cold users for MultVAE predict
-        join_type = "inner" if str(self) == "MultVAE" else "left"
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        recs = (
-            users.join(log, how=join_type, on="user_idx")
-            .select("user_idx", "item_idx")
-            .groupby("user_idx")
-            .applyInPandas(grouped_map, rec_schema)
-        )
-        return recs
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,
-        item_features: Optional[SparkDataFrame] = None,
-    ) -> SparkDataFrame:
-        items_count = self._item_dim
-        model = self.model.cpu()
-        agg_fn = self._predict_by_user_pairs
-        users = pairs.select("user_idx").distinct()
-
-        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-            return agg_fn(pandas_df, model, items_count)[
-                ["user_idx", "item_idx", "relevance"]
-            ]
-
-        self.logger.debug("Calculate relevance for user-item pairs")
-        user_history = (
-            users.join(log, how="inner", on="user_idx")
-            .groupBy("user_idx")
-            .agg(sf.collect_list("item_idx").alias("item_idx_history"))
-        )
-        user_pairs = pairs.groupBy("user_idx").agg(
-            sf.collect_list("item_idx").alias("item_idx_to_pred")
-        )
-        full_df = user_pairs.join(user_history, on="user_idx", how="inner")
-
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        recs = full_df.groupby("user_idx").applyInPandas(
-            grouped_map, rec_schema
-        )
-
-        return recs
-
-    @staticmethod
-    @abstractmethod
-    def _predict_by_user(
-        pandas_df: PandasDataFrame,
-        model: nn.Module,
-        items_np: np.ndarray,
-        k: int,
-        item_count: int,
-    ) -> PandasDataFrame:
-        """
-        Calculate predictions.
-
-        :param pandas_df: DataFrame with user-item interactions ``[user_idx, item_idx]``
-        :param model: trained model
-        :param items_np: items available for recommendations
-        :param k: length of recommendation list
-        :param item_count: total number of items
-        :return: DataFrame ``[user_idx, item_idx, relevance]``
-        """
-
-    @staticmethod
-    @abstractmethod
-    def _predict_by_user_pairs(
-        pandas_df: PandasDataFrame,
-        model: nn.Module,
-        item_count: int,
-    ) -> PandasDataFrame:
-        """
-        Get relevance for provided pairs
-
-        :param pandas_df: DataFrame with rated items and items that need prediction
-            ``[user_idx, item_idx_history, item_idx_to_pred]``
-        :param model: trained model
-        :param item_count: total number of items
-        :return: DataFrame ``[user_idx, item_idx, relevance]``
-        """
-
-    def load_model(self, path: str) -> None:
-        """
-        Load model from file
-
-        :param path: path to model
-        :return:
-        """
-        self.logger.debug("-- Loading model from file")
-        self.model.load_state_dict(torch.load(path))
-
-    def _save_model(self, path: str) -> None:
-        torch.save(self.model.state_dict(), path)