replay-rec 0.20.1rc0__py3-none-any.whl → 0.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/METADATA +18 -12
  3. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/RECORD +6 -61
  4. replay/experimental/__init__.py +0 -0
  5. replay/experimental/metrics/__init__.py +0 -62
  6. replay/experimental/metrics/base_metric.py +0 -603
  7. replay/experimental/metrics/coverage.py +0 -97
  8. replay/experimental/metrics/experiment.py +0 -175
  9. replay/experimental/metrics/hitrate.py +0 -26
  10. replay/experimental/metrics/map.py +0 -30
  11. replay/experimental/metrics/mrr.py +0 -18
  12. replay/experimental/metrics/ncis_precision.py +0 -31
  13. replay/experimental/metrics/ndcg.py +0 -49
  14. replay/experimental/metrics/precision.py +0 -22
  15. replay/experimental/metrics/recall.py +0 -25
  16. replay/experimental/metrics/rocauc.py +0 -49
  17. replay/experimental/metrics/surprisal.py +0 -90
  18. replay/experimental/metrics/unexpectedness.py +0 -76
  19. replay/experimental/models/__init__.py +0 -50
  20. replay/experimental/models/admm_slim.py +0 -257
  21. replay/experimental/models/base_neighbour_rec.py +0 -200
  22. replay/experimental/models/base_rec.py +0 -1386
  23. replay/experimental/models/base_torch_rec.py +0 -234
  24. replay/experimental/models/cql.py +0 -454
  25. replay/experimental/models/ddpg.py +0 -932
  26. replay/experimental/models/dt4rec/__init__.py +0 -0
  27. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  28. replay/experimental/models/dt4rec/gpt1.py +0 -401
  29. replay/experimental/models/dt4rec/trainer.py +0 -127
  30. replay/experimental/models/dt4rec/utils.py +0 -264
  31. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  32. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  33. replay/experimental/models/hierarchical_recommender.py +0 -331
  34. replay/experimental/models/implicit_wrap.py +0 -131
  35. replay/experimental/models/lightfm_wrap.py +0 -303
  36. replay/experimental/models/mult_vae.py +0 -332
  37. replay/experimental/models/neural_ts.py +0 -986
  38. replay/experimental/models/neuromf.py +0 -406
  39. replay/experimental/models/scala_als.py +0 -293
  40. replay/experimental/models/u_lin_ucb.py +0 -115
  41. replay/experimental/nn/data/__init__.py +0 -1
  42. replay/experimental/nn/data/schema_builder.py +0 -102
  43. replay/experimental/preprocessing/__init__.py +0 -3
  44. replay/experimental/preprocessing/data_preparator.py +0 -839
  45. replay/experimental/preprocessing/padder.py +0 -229
  46. replay/experimental/preprocessing/sequence_generator.py +0 -208
  47. replay/experimental/scenarios/__init__.py +0 -1
  48. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  49. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  50. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  51. replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  52. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  53. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  54. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  55. replay/experimental/utils/__init__.py +0 -0
  56. replay/experimental/utils/logger.py +0 -24
  57. replay/experimental/utils/model_handler.py +0 -186
  58. replay/experimental/utils/session_handler.py +0 -44
  59. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/WHEEL +0 -0
  60. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/licenses/LICENSE +0 -0
  61. {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/licenses/NOTICE +0 -0
replay/experimental/models/lightfm_wrap.py
@@ -1,303 +0,0 @@
- import os
- from os.path import join
- from typing import Optional
-
- import numpy as np
- from scipy.sparse import csr_matrix, diags, hstack
- from sklearn.preprocessing import MinMaxScaler
-
- from replay.data import get_schema
- from replay.experimental.models.base_rec import HybridRecommender
- from replay.experimental.utils.session_handler import State
- from replay.preprocessing import CSRConverter
- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
- from replay.utils.spark_utils import check_numeric, load_pickled_from_parquet, save_picklable_to_parquet
-
- if PYSPARK_AVAILABLE:
-     import pyspark.sql.functions as sf
-
-
- class LightFMWrap(HybridRecommender):
-     """Wrapper for LightFM."""
-
-     epochs: int = 10
-     _search_space = {
-         "loss": {
-             "type": "categorical",
-             "args": ["logistic", "bpr", "warp", "warp-kos"],
-         },
-         "no_components": {"type": "loguniform_int", "args": [8, 512]},
-     }
-     user_feat_scaler: Optional[MinMaxScaler] = None
-     item_feat_scaler: Optional[MinMaxScaler] = None
-
-     def __init__(
-         self,
-         no_components: int = 128,
-         loss: str = "warp",
-         random_state: Optional[int] = None,
-     ):
-         np.random.seed(42)
-         self.no_components = no_components
-         self.loss = loss
-         self.random_state = random_state
-         cpu_count = os.cpu_count()
-         self.num_threads = cpu_count if cpu_count is not None else 1
-
-     @property
-     def _init_args(self):
-         return {
-             "no_components": self.no_components,
-             "loss": self.loss,
-             "random_state": self.random_state,
-         }
-
-     def _save_model(self, path: str):
-         save_picklable_to_parquet(self.model, join(path, "model"))
-         save_picklable_to_parquet(self.user_feat_scaler, join(path, "user_feat_scaler"))
-         save_picklable_to_parquet(self.item_feat_scaler, join(path, "item_feat_scaler"))
-
-     def _load_model(self, path: str):
-         self.model = load_pickled_from_parquet(join(path, "model"))
-         self.user_feat_scaler = load_pickled_from_parquet(join(path, "user_feat_scaler"))
-         self.item_feat_scaler = load_pickled_from_parquet(join(path, "item_feat_scaler"))
-
-     def _feature_table_to_csr(
-         self,
-         log_ids_list: SparkDataFrame,
-         feature_table: Optional[SparkDataFrame] = None,
-     ) -> Optional[csr_matrix]:
-         """
-         Transform features to sparse matrix
-         Matrix consists of two parts:
-         1) Left one is a one-hot encoding of user and item ids.
-         Matrix size is: number of users or items * number of users or items in fit.
-         Cold users and items are represented with empty strings
-         2) Right one is numerical features, passed with feature_table.
-         MinMaxScaler is applied per column, and then each value is divided by the row sum.
-
-         :param feature_table: dataframe with ``user_idx`` or ``item_idx``,
-             other columns are features.
-         :param log_ids_list: dataframe with ``user_idx`` or ``item_idx``,
-             containing unique ids from log.
-         :returns: feature matrix
-         """
-
-         if feature_table is None:
-             return None
-
-         check_numeric(feature_table)
-         log_ids_list = log_ids_list.distinct()
-         entity = "item" if "item_idx" in feature_table.columns else "user"
-         idx_col_name = f"{entity}_idx"
-
-         # filter features by log
-         feature_table = feature_table.join(log_ids_list, on=idx_col_name, how="inner")
-
-         fit_dim = getattr(self, f"_{entity}_dim")
-         matrix_height = max(
-             fit_dim,
-             log_ids_list.select(sf.max(idx_col_name)).first()[0] + 1,
-         )
-         if not feature_table.rdd.isEmpty():
-             matrix_height = max(
-                 matrix_height,
-                 feature_table.select(sf.max(idx_col_name)).first()[0] + 1,
-             )
-
-         features_np = (
-             feature_table.select(
-                 idx_col_name,
-                 # first column contains id, next contain features
-                 *sorted(set(feature_table.columns).difference({idx_col_name})),
-             )
-             .toPandas()
-             .to_numpy()
-         )
-         entities_ids = features_np[:, 0]
-         features_np = features_np[:, 1:]
-         number_of_features = features_np.shape[1]
-
-         all_ids_list = log_ids_list.toPandas().to_numpy().ravel()
-         entities_seen_in_fit = all_ids_list[all_ids_list < fit_dim]
-
-         entity_id_features = csr_matrix(
-             (
-                 [1.0] * entities_seen_in_fit.shape[0],
-                 (entities_seen_in_fit, entities_seen_in_fit),
-             ),
-             shape=(matrix_height, fit_dim),
-         )
-
-         scaler_name = f"{entity}_feat_scaler"
-         if getattr(self, scaler_name) is None:
-             if not features_np.size:
-                 msg = f"features for {entity}s from log are absent"
-                 raise ValueError(msg)
-             setattr(self, scaler_name, MinMaxScaler().fit(features_np))
-
-         if features_np.size:
-             features_np = getattr(self, scaler_name).transform(features_np)
-             sparse_features = csr_matrix(
-                 (
-                     features_np.ravel(),
-                     (
-                         np.repeat(entities_ids, number_of_features),
-                         np.tile(
-                             np.arange(number_of_features),
-                             entities_ids.shape[0],
-                         ),
-                     ),
-                 ),
-                 shape=(matrix_height, number_of_features),
-             )
-
-         else:
-             sparse_features = csr_matrix((matrix_height, number_of_features))
-
-         concat_features = hstack([entity_id_features, sparse_features])
-         concat_features_sum = concat_features.sum(axis=1).A.ravel()
-         mask = concat_features_sum != 0.0
-         concat_features_sum[mask] = 1.0 / concat_features_sum[mask]
-         return diags(concat_features_sum, format="csr") @ concat_features
-
-     def _fit(
-         self,
-         log: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ) -> None:
-         from lightfm import LightFM
-
-         self.user_feat_scaler = None
-         self.item_feat_scaler = None
-
-         interactions_matrix = CSRConverter(
-             first_dim_column="user_idx",
-             second_dim_column="item_idx",
-             data_column="relevance",
-             row_count=self._user_dim,
-             column_count=self._item_dim,
-         ).transform(log)
-         csr_item_features = self._feature_table_to_csr(log.select("item_idx").distinct(), item_features)
-         csr_user_features = self._feature_table_to_csr(log.select("user_idx").distinct(), user_features)
-
-         if user_features is not None:
-             self.can_predict_cold_users = True
-         if item_features is not None:
-             self.can_predict_cold_items = True
-
-         self.model = LightFM(
-             loss=self.loss,
-             no_components=self.no_components,
-             random_state=self.random_state,
-         ).fit(
-             interactions=interactions_matrix,
-             epochs=self.epochs,
-             num_threads=self.num_threads,
-             item_features=csr_item_features,
-             user_features=csr_user_features,
-         )
-
-     def _predict_selected_pairs(
-         self,
-         pairs: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ):
-         def predict_by_user(pandas_df: PandasDataFrame) -> PandasDataFrame:
-             pandas_df["relevance"] = model.predict(
-                 user_ids=pandas_df["user_idx"].to_numpy(),
-                 item_ids=pandas_df["item_idx"].to_numpy(),
-                 item_features=csr_item_features,
-                 user_features=csr_user_features,
-             )
-             return pandas_df
-
-         model = self.model
-
-         if self.can_predict_cold_users and user_features is None:
-             msg = "User features are missing for predict"
-             raise ValueError(msg)
-         if self.can_predict_cold_items and item_features is None:
-             msg = "Item features are missing for predict"
-             raise ValueError(msg)
-
-         csr_item_features = self._feature_table_to_csr(pairs.select("item_idx").distinct(), item_features)
-         csr_user_features = self._feature_table_to_csr(pairs.select("user_idx").distinct(), user_features)
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         return pairs.groupby("user_idx").applyInPandas(predict_by_user, rec_schema)
-
-     def _predict(
-         self,
-         log: SparkDataFrame,  # noqa: ARG002
-         k: int,  # noqa: ARG002
-         users: SparkDataFrame,
-         items: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-         filter_seen_items: bool = True,  # noqa: ARG002
-     ) -> SparkDataFrame:
-         return self._predict_selected_pairs(users.crossJoin(items), user_features, item_features)
-
-     def _predict_pairs(
-         self,
-         pairs: SparkDataFrame,
-         log: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         user_features: Optional[SparkDataFrame] = None,
-         item_features: Optional[SparkDataFrame] = None,
-     ) -> SparkDataFrame:
-         return self._predict_selected_pairs(pairs, user_features, item_features)
-
-     def _get_features(
-         self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
-     ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
-         """
-         Get features from LightFM.
-         LightFM has methods get_item_representations/get_user_representations,
-         which accept object matrix and return features.
-
-         :param ids: id item_idx/user_idx to get features for
-         :param features: features for item_idx/user_idx
-         :return: spark-dataframe with biases and vectors for users/items and vector size
-         """
-         entity = "item" if "item_idx" in ids.columns else "user"
-         ids_list = ids.toPandas()[f"{entity}_idx"]
-
-         # models without features use sparse matrix
-         if features is None:
-             matrix_width = getattr(self, f"fit_{entity}s").count()
-             warm_ids = ids_list[ids_list < matrix_width]
-             sparse_features = csr_matrix(
-                 (
-                     [1] * warm_ids.shape[0],
-                     (warm_ids, warm_ids),
-                 ),
-                 shape=(ids_list.max() + 1, matrix_width),
-             )
-         else:
-             sparse_features = self._feature_table_to_csr(ids, features)
-
-         biases, vectors = getattr(self.model, f"get_{entity}_representations")(sparse_features)
-
-         embed_list = list(
-             zip(
-                 ids_list,
-                 biases[ids_list].tolist(),
-                 vectors[ids_list].tolist(),
-             )
-         )
-         lightfm_factors = State().session.createDataFrame(
-             embed_list,
-             schema=[
-                 f"{entity}_idx",
-                 f"{entity}_bias",
-                 f"{entity}_factors",
-             ],
-         )
-         return lightfm_factors, self.model.no_components
replay/experimental/models/mult_vae.py
@@ -1,332 +0,0 @@
- """
- MultVAE implementation
- (Variational Autoencoders for Collaborative Filtering)
- """
-
- from typing import Optional
-
- import numpy as np
- import torch
- import torch.nn.functional as sf
- from scipy.sparse import csr_matrix
- from sklearn.model_selection import GroupShuffleSplit
- from torch import nn
- from torch.optim import Adam
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from torch.utils.data import DataLoader, TensorDataset
-
- from replay.experimental.models.base_torch_rec import TorchRecommender
- from replay.utils import PandasDataFrame, SparkDataFrame
-
-
- class VAE(nn.Module):
-     """Base variational autoencoder"""
-
-     def __init__(
-         self,
-         item_count: int,
-         latent_dim: int,
-         hidden_dim: int = 600,
-         dropout: float = 0.3,
-     ):
-         """
-         :param item_count: number of items
-         :param latent_dim: latent dimension size
-         :param hidden_dim: hidden dimension size for encoder and decoder
-         :param dropout: dropout coefficient
-         """
-         super().__init__()
-
-         self.latent_dim = latent_dim
-         self.encoder_dims = [item_count, hidden_dim, latent_dim * 2]
-         self.decoder_dims = [latent_dim, hidden_dim, item_count]
-
-         self.encoder = nn.ModuleList(
-             [nn.Linear(d_in, d_out) for d_in, d_out in zip(self.encoder_dims[:-1], self.encoder_dims[1:])]
-         )
-         self.decoder = nn.ModuleList(
-             [nn.Linear(d_in, d_out) for d_in, d_out in zip(self.decoder_dims[:-1], self.decoder_dims[1:])]
-         )
-         self.dropout = nn.Dropout(dropout)
-         self.activation = torch.nn.ReLU()
-
-         for layer in self.encoder:
-             self.weight_init(layer)
-
-         for layer in self.decoder:
-             self.weight_init(layer)
-
-     def encode(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-         """Encode"""
-         hidden = sf.normalize(batch, p=2, dim=1)
-         hidden = self.dropout(hidden)
-
-         for layer in self.encoder[:-1]:
-             hidden = layer(hidden)
-             hidden = self.activation(hidden)
-
-         hidden = self.encoder[-1](hidden)
-         mu_latent = hidden[:, : self.latent_dim]
-         logvar_latent = hidden[:, self.latent_dim :]
-         return mu_latent, logvar_latent
-
-     def reparameterize(self, mu_latent: torch.Tensor, logvar_latent: torch.Tensor) -> torch.Tensor:
-         """Reparametrization trick"""
-
-         if self.training:
-             std = torch.exp(0.5 * logvar_latent)
-             eps = torch.randn_like(std)
-             return eps * std + mu_latent
-         return mu_latent
-
-     def decode(self, z_latent: torch.Tensor) -> torch.Tensor:
-         """Decode"""
-         hidden = z_latent
-         for layer in self.decoder[:-1]:
-             hidden = layer(hidden)
-             hidden = self.activation(hidden)
-         return self.decoder[-1](hidden)
-
-     def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         """
-         :param batch: user batch
-         :return: output, expectation and logarithm of variation
-         """
-         mu_latent, logvar_latent = self.encode(batch)
-         z_latent = self.reparameterize(mu_latent, logvar_latent)
-         return self.decode(z_latent), mu_latent, logvar_latent
-
-     @staticmethod
-     def weight_init(layer: nn.Module):
-         """
-         Xavier initialization
-
-         :param layer: layer of a model
-         """
-         if isinstance(layer, nn.Linear):
-             nn.init.xavier_normal_(layer.weight.data)
-             layer.bias.data.normal_(0.0, 0.001)
-
-
- class MultVAE(TorchRecommender):
-     """`Variational Autoencoders for Collaborative Filtering
-     <https://arxiv.org/pdf/1802.05814.pdf>`_"""
-
-     num_workers: int = 0
-     batch_size_users: int = 5000
-     patience: int = 10
-     n_saved: int = 2
-     valid_split_size: float = 0.1
-     seed: int = 42
-     can_predict_cold_users = True
-     train_user_batch: csr_matrix
-     valid_user_batch: csr_matrix
-     _search_space = {
-         "learning_rate": {"type": "loguniform", "args": [0.0001, 0.5]},
-         "epochs": {"type": "int", "args": [100, 100]},
-         "latent_dim": {"type": "int", "args": [200, 200]},
-         "hidden_dim": {"type": "int", "args": [600, 600]},
-         "dropout": {"type": "uniform", "args": [0, 0.5]},
-         "anneal": {"type": "uniform", "args": [0.2, 1]},
-         "l2_reg": {"type": "loguniform", "args": [1e-9, 5]},
-         "factor": {"type": "uniform", "args": [0.2, 0.2]},
-         "patience": {"type": "int", "args": [3, 3]},
-     }
-
-     def __init__(
-         self,
-         learning_rate: float = 0.01,
-         epochs: int = 100,
-         latent_dim: int = 200,
-         hidden_dim: int = 600,
-         dropout: float = 0.3,
-         anneal: float = 0.1,
-         l2_reg: float = 0,
-         factor: float = 0.2,
-         patience: int = 3,
-     ):
-         """
-         :param learning_rate: learning rate
-         :param epochs: number of epochs to train model
-         :param latent_dim: latent dimension size for user vectors
-         :param hidden_dim: hidden dimension size for encoder and decoder
-         :param dropout: dropout coefficient
-         :param anneal: anneal coefficient [0,1]
-         :param l2_reg: l2 regularization term
-         :param factor: ReduceLROnPlateau reducing factor. new_lr = lr * factor
-         :param patience: number of non-improved epochs before reducing lr
-         """
-         super().__init__()
-         self.learning_rate = learning_rate
-         self.epochs = epochs
-         self.latent_dim = latent_dim
-         self.hidden_dim = hidden_dim
-         self.dropout = dropout
-         self.anneal = anneal
-         self.l2_reg = l2_reg
-         self.factor = factor
-         self.patience = patience
-
-     @property
-     def _init_args(self):
-         return {
-             "learning_rate": self.learning_rate,
-             "epochs": self.epochs,
-             "latent_dim": self.latent_dim,
-             "hidden_dim": self.hidden_dim,
-             "dropout": self.dropout,
-             "anneal": self.anneal,
-             "l2_reg": self.l2_reg,
-             "factor": self.factor,
-             "patience": self.patience,
-         }
-
-     def _get_data_loader(
-         self, data: PandasDataFrame, shuffle: bool = True
-     ) -> tuple[csr_matrix, DataLoader, np.ndarray]:
-         """get data loader and matrix with data"""
-         users_count = data["user_idx"].value_counts().count()
-         user_idx = data["user_idx"].astype("category").cat
-         user_batch = csr_matrix(
-             (
-                 np.ones(len(data["user_idx"])),
-                 ([user_idx.codes.values, data["item_idx"].values]),
-             ),
-             shape=(users_count, self._item_dim),
-         )
-         data_loader = DataLoader(
-             TensorDataset(torch.arange(users_count).long()),
-             batch_size=self.batch_size_users,
-             shuffle=shuffle,
-             num_workers=self.num_workers,
-         )
-
-         return user_batch, data_loader, user_idx.categories.values
-
-     def _fit(
-         self,
-         log: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-     ) -> None:
-         self.logger.debug("Creating batch")
-         data = log.select("user_idx", "item_idx").toPandas()
-         splitter = GroupShuffleSplit(n_splits=1, test_size=self.valid_split_size, random_state=self.seed)
-         train_idx, valid_idx = next(splitter.split(data, groups=data["user_idx"]))
-         train_data, valid_data = data.iloc[train_idx], data.iloc[valid_idx]
-
-         self.train_user_batch, train_data_loader, _ = self._get_data_loader(train_data)
-         self.valid_user_batch, valid_data_loader, _ = self._get_data_loader(valid_data, False)
-
-         self.logger.debug("Training VAE")
-         self.model = VAE(
-             item_count=self._item_dim,
-             latent_dim=self.latent_dim,
-             hidden_dim=self.hidden_dim,
-             dropout=self.dropout,
-         ).to(self.device)
-         optimizer = Adam(
-             self.model.parameters(),
-             lr=self.learning_rate,
-             weight_decay=self.l2_reg / self.batch_size_users,
-         )
-         lr_scheduler = ReduceLROnPlateau(optimizer, factor=self.factor, patience=self.patience)
-
-         self.train(
-             train_data_loader,
-             valid_data_loader,
-             optimizer,
-             lr_scheduler,
-             self.epochs,
-             "multvae",
-         )
-
-     def _loss(self, y_pred, y_true, mu_latent, logvar_latent):
-         log_softmax_var = sf.log_softmax(y_pred, dim=1)
-         bce = -(log_softmax_var * y_true).sum(dim=1).mean()
-         kld = (
-             -0.5
-             * torch.sum(
-                 1 + logvar_latent - mu_latent.pow(2) - logvar_latent.exp(),
-                 dim=1,
-             ).mean()
-         )
-         return bce + self.anneal * kld
-
-     def _batch_pass(self, batch, model):
-         full_batch = self.train_user_batch if model.training else self.valid_user_batch
-         user_batch = torch.FloatTensor(full_batch[batch[0]].toarray()).to(self.device)
-         pred_user_batch, latent_mu, latent_logvar = self.model.forward(user_batch)
-         return {
-             "y_pred": pred_user_batch,
-             "y_true": user_batch,
-             "mu_latent": latent_mu,
-             "logvar_latent": latent_logvar,
-         }
-
-     @staticmethod
-     def _predict_pairs_inner(
-         model: nn.Module,
-         user_idx: int,
-         items_np_history: np.ndarray,
-         items_np_to_pred: np.ndarray,
-         item_count: int,
-         cnt: Optional[int] = None,
-     ) -> SparkDataFrame:
-         model.eval()
-         with torch.no_grad():
-             user_batch = torch.zeros((1, item_count))
-             user_batch[0, items_np_history] = 1
-             user_recs = sf.softmax(model(user_batch)[0][0].detach(), dim=0)
-             if cnt is not None:
-                 best_item_idx = (torch.argsort(user_recs[items_np_to_pred], descending=True)[:cnt]).numpy()
-                 items_np_to_pred = items_np_to_pred[best_item_idx]
-             return PandasDataFrame(
-                 {
-                     "user_idx": np.array(items_np_to_pred.shape[0] * [user_idx]),
-                     "item_idx": items_np_to_pred,
-                     "relevance": user_recs[items_np_to_pred],
-                 }
-             )
-
-     @staticmethod
-     def _predict_by_user(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         items_np: np.ndarray,
-         k: int,
-         item_count: int,
-     ) -> PandasDataFrame:
-         return MultVAE._predict_pairs_inner(
-             model=model,
-             user_idx=pandas_df["user_idx"][0],
-             items_np_history=pandas_df["item_idx"].values,
-             items_np_to_pred=items_np,
-             item_count=item_count,
-             cnt=min(len(pandas_df) + k, len(items_np)),
-         )
-
-     @staticmethod
-     def _predict_by_user_pairs(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         item_count: int,
-     ) -> PandasDataFrame:
-         return MultVAE._predict_pairs_inner(
-             model=model,
-             user_idx=pandas_df["user_idx"][0],
-             items_np_history=np.array(pandas_df["item_idx_history"][0]),
-             items_np_to_pred=np.array(pandas_df["item_idx_to_pred"][0]),
-             item_count=item_count,
-             cnt=None,
-         )
-
-     def _load_model(self, path: str):
-         self.model = VAE(
-             item_count=self._item_dim,
-             latent_dim=self.latent_dim,
-             hidden_dim=self.hidden_dim,
-             dropout=self.dropout,
-         ).to(self.device)
-         self.model.load_state_dict(torch.load(path))
-         self.model.eval()