replay-rec 0.20.3__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +62 -0
- replay/experimental/metrics/base_metric.py +603 -0
- replay/experimental/metrics/coverage.py +97 -0
- replay/experimental/metrics/experiment.py +175 -0
- replay/experimental/metrics/hitrate.py +26 -0
- replay/experimental/metrics/map.py +30 -0
- replay/experimental/metrics/mrr.py +18 -0
- replay/experimental/metrics/ncis_precision.py +31 -0
- replay/experimental/metrics/ndcg.py +49 -0
- replay/experimental/metrics/precision.py +22 -0
- replay/experimental/metrics/recall.py +25 -0
- replay/experimental/metrics/rocauc.py +49 -0
- replay/experimental/metrics/surprisal.py +90 -0
- replay/experimental/metrics/unexpectedness.py +76 -0
- replay/experimental/models/__init__.py +50 -0
- replay/experimental/models/admm_slim.py +257 -0
- replay/experimental/models/base_neighbour_rec.py +200 -0
- replay/experimental/models/base_rec.py +1386 -0
- replay/experimental/models/base_torch_rec.py +234 -0
- replay/experimental/models/cql.py +454 -0
- replay/experimental/models/ddpg.py +932 -0
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +189 -0
- replay/experimental/models/dt4rec/gpt1.py +401 -0
- replay/experimental/models/dt4rec/trainer.py +127 -0
- replay/experimental/models/dt4rec/utils.py +264 -0
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
- replay/experimental/models/hierarchical_recommender.py +331 -0
- replay/experimental/models/implicit_wrap.py +131 -0
- replay/experimental/models/lightfm_wrap.py +303 -0
- replay/experimental/models/mult_vae.py +332 -0
- replay/experimental/models/neural_ts.py +986 -0
- replay/experimental/models/neuromf.py +406 -0
- replay/experimental/models/scala_als.py +293 -0
- replay/experimental/models/u_lin_ucb.py +115 -0
- replay/experimental/nn/data/__init__.py +1 -0
- replay/experimental/nn/data/schema_builder.py +102 -0
- replay/experimental/preprocessing/__init__.py +3 -0
- replay/experimental/preprocessing/data_preparator.py +839 -0
- replay/experimental/preprocessing/padder.py +229 -0
- replay/experimental/preprocessing/sequence_generator.py +208 -0
- replay/experimental/scenarios/__init__.py +1 -0
- replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
- replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +117 -0
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +24 -0
- replay/experimental/utils/model_handler.py +186 -0
- replay/experimental/utils/session_handler.py +44 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
- {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +61 -6
- {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
replay/experimental/models/lightfm_wrap.py
@@ -0,0 +1,303 @@
+import os
+from os.path import join
+from typing import Optional
+
+import numpy as np
+from scipy.sparse import csr_matrix, diags, hstack
+from sklearn.preprocessing import MinMaxScaler
+
+from replay.data import get_schema
+from replay.experimental.models.base_rec import HybridRecommender
+from replay.experimental.utils.session_handler import State
+from replay.preprocessing import CSRConverter
+from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
+from replay.utils.spark_utils import check_numeric, load_pickled_from_parquet, save_picklable_to_parquet
+
+if PYSPARK_AVAILABLE:
+    import pyspark.sql.functions as sf
+
+
+class LightFMWrap(HybridRecommender):
+    """Wrapper for LightFM."""
+
+    epochs: int = 10
+    _search_space = {
+        "loss": {
+            "type": "categorical",
+            "args": ["logistic", "bpr", "warp", "warp-kos"],
+        },
+        "no_components": {"type": "loguniform_int", "args": [8, 512]},
+    }
+    user_feat_scaler: Optional[MinMaxScaler] = None
+    item_feat_scaler: Optional[MinMaxScaler] = None
+
+    def __init__(
+        self,
+        no_components: int = 128,
+        loss: str = "warp",
+        random_state: Optional[int] = None,
+    ):
+        np.random.seed(42)
+        self.no_components = no_components
+        self.loss = loss
+        self.random_state = random_state
+        cpu_count = os.cpu_count()
+        self.num_threads = cpu_count if cpu_count is not None else 1
+
+    @property
+    def _init_args(self):
+        return {
+            "no_components": self.no_components,
+            "loss": self.loss,
+            "random_state": self.random_state,
+        }
+
+    def _save_model(self, path: str):
+        save_picklable_to_parquet(self.model, join(path, "model"))
+        save_picklable_to_parquet(self.user_feat_scaler, join(path, "user_feat_scaler"))
+        save_picklable_to_parquet(self.item_feat_scaler, join(path, "item_feat_scaler"))
+
+    def _load_model(self, path: str):
+        self.model = load_pickled_from_parquet(join(path, "model"))
+        self.user_feat_scaler = load_pickled_from_parquet(join(path, "user_feat_scaler"))
+        self.item_feat_scaler = load_pickled_from_parquet(join(path, "item_feat_scaler"))
+
+    def _feature_table_to_csr(
+        self,
+        log_ids_list: SparkDataFrame,
+        feature_table: Optional[SparkDataFrame] = None,
+    ) -> Optional[csr_matrix]:
+        """
+        Transform features to sparse matrix.
+        Matrix consists of two parts:
+        1) Left one is a one-hot encoding of user and item ids.
+        Matrix size is: number of users or items * number of users or items in fit.
+        Cold users and items are represented with empty rows.
+        2) Right one is numerical features, passed with ``feature_table``.
+        MinMaxScaler is applied per column, and then each value is divided by the row sum.
+
+        :param feature_table: dataframe with ``user_idx`` or ``item_idx``,
+            other columns are features.
+        :param log_ids_list: dataframe with ``user_idx`` or ``item_idx``,
+            containing unique ids from log.
+        :returns: feature matrix
+        """
+
+        if feature_table is None:
+            return None
+
+        check_numeric(feature_table)
+        log_ids_list = log_ids_list.distinct()
+        entity = "item" if "item_idx" in feature_table.columns else "user"
+        idx_col_name = f"{entity}_idx"
+
+        # filter features by log
+        feature_table = feature_table.join(log_ids_list, on=idx_col_name, how="inner")
+
+        fit_dim = getattr(self, f"_{entity}_dim")
+        matrix_height = max(
+            fit_dim,
+            log_ids_list.select(sf.max(idx_col_name)).first()[0] + 1,
+        )
+        if not feature_table.rdd.isEmpty():
+            matrix_height = max(
+                matrix_height,
+                feature_table.select(sf.max(idx_col_name)).first()[0] + 1,
+            )
+
+        features_np = (
+            feature_table.select(
+                idx_col_name,
+                # first column contains id, next contain features
+                *sorted(set(feature_table.columns).difference({idx_col_name})),
+            )
+            .toPandas()
+            .to_numpy()
+        )
+        entities_ids = features_np[:, 0]
+        features_np = features_np[:, 1:]
+        number_of_features = features_np.shape[1]
+
+        all_ids_list = log_ids_list.toPandas().to_numpy().ravel()
+        entities_seen_in_fit = all_ids_list[all_ids_list < fit_dim]
+
+        entity_id_features = csr_matrix(
+            (
+                [1.0] * entities_seen_in_fit.shape[0],
+                (entities_seen_in_fit, entities_seen_in_fit),
+            ),
+            shape=(matrix_height, fit_dim),
+        )
+
+        scaler_name = f"{entity}_feat_scaler"
+        if getattr(self, scaler_name) is None:
+            if not features_np.size:
+                msg = f"features for {entity}s from log are absent"
+                raise ValueError(msg)
+            setattr(self, scaler_name, MinMaxScaler().fit(features_np))
+
+        if features_np.size:
+            features_np = getattr(self, scaler_name).transform(features_np)
+            sparse_features = csr_matrix(
+                (
+                    features_np.ravel(),
+                    (
+                        np.repeat(entities_ids, number_of_features),
+                        np.tile(
+                            np.arange(number_of_features),
+                            entities_ids.shape[0],
+                        ),
+                    ),
+                ),
+                shape=(matrix_height, number_of_features),
+            )
+
+        else:
+            sparse_features = csr_matrix((matrix_height, number_of_features))
+
+        concat_features = hstack([entity_id_features, sparse_features])
+        concat_features_sum = concat_features.sum(axis=1).A.ravel()
+        mask = concat_features_sum != 0.0
+        concat_features_sum[mask] = 1.0 / concat_features_sum[mask]
+        return diags(concat_features_sum, format="csr") @ concat_features
+
+    def _fit(
+        self,
+        log: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,
+        item_features: Optional[SparkDataFrame] = None,
+    ) -> None:
+        from lightfm import LightFM
+
+        self.user_feat_scaler = None
+        self.item_feat_scaler = None
+
+        interactions_matrix = CSRConverter(
+            first_dim_column="user_idx",
+            second_dim_column="item_idx",
+            data_column="relevance",
+            row_count=self._user_dim,
+            column_count=self._item_dim,
+        ).transform(log)
+        csr_item_features = self._feature_table_to_csr(log.select("item_idx").distinct(), item_features)
+        csr_user_features = self._feature_table_to_csr(log.select("user_idx").distinct(), user_features)
+
+        if user_features is not None:
+            self.can_predict_cold_users = True
+        if item_features is not None:
+            self.can_predict_cold_items = True
+
+        self.model = LightFM(
+            loss=self.loss,
+            no_components=self.no_components,
+            random_state=self.random_state,
+        ).fit(
+            interactions=interactions_matrix,
+            epochs=self.epochs,
+            num_threads=self.num_threads,
+            item_features=csr_item_features,
+            user_features=csr_user_features,
+        )
+
+    def _predict_selected_pairs(
+        self,
+        pairs: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,
+        item_features: Optional[SparkDataFrame] = None,
+    ):
+        def predict_by_user(pandas_df: PandasDataFrame) -> PandasDataFrame:
+            pandas_df["relevance"] = model.predict(
+                user_ids=pandas_df["user_idx"].to_numpy(),
+                item_ids=pandas_df["item_idx"].to_numpy(),
+                item_features=csr_item_features,
+                user_features=csr_user_features,
+            )
+            return pandas_df
+
+        model = self.model
+
+        if self.can_predict_cold_users and user_features is None:
+            msg = "User features are missing for predict"
+            raise ValueError(msg)
+        if self.can_predict_cold_items and item_features is None:
+            msg = "Item features are missing for predict"
+            raise ValueError(msg)
+
+        csr_item_features = self._feature_table_to_csr(pairs.select("item_idx").distinct(), item_features)
+        csr_user_features = self._feature_table_to_csr(pairs.select("user_idx").distinct(), user_features)
+        rec_schema = get_schema(
+            query_column="user_idx",
+            item_column="item_idx",
+            rating_column="relevance",
+            has_timestamp=False,
+        )
+        return pairs.groupby("user_idx").applyInPandas(predict_by_user, rec_schema)
+
+    def _predict(
+        self,
+        log: SparkDataFrame,  # noqa: ARG002
+        k: int,  # noqa: ARG002
+        users: SparkDataFrame,
+        items: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,
+        item_features: Optional[SparkDataFrame] = None,
+        filter_seen_items: bool = True,  # noqa: ARG002
+    ) -> SparkDataFrame:
+        return self._predict_selected_pairs(users.crossJoin(items), user_features, item_features)
+
+    def _predict_pairs(
+        self,
+        pairs: SparkDataFrame,
+        log: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        user_features: Optional[SparkDataFrame] = None,
+        item_features: Optional[SparkDataFrame] = None,
+    ) -> SparkDataFrame:
+        return self._predict_selected_pairs(pairs, user_features, item_features)
+
+    def _get_features(
+        self, ids: SparkDataFrame, features: Optional[SparkDataFrame]
+    ) -> tuple[Optional[SparkDataFrame], Optional[int]]:
+        """
+        Get features from LightFM.
+        LightFM has methods get_item_representations/get_user_representations,
+        which accept an object feature matrix and return features.
+
+        :param ids: item_idx/user_idx ids to get features for
+        :param features: features for item_idx/user_idx
+        :return: spark-dataframe with biases and vectors for users/items and vector size
+        """
+        entity = "item" if "item_idx" in ids.columns else "user"
+        ids_list = ids.toPandas()[f"{entity}_idx"]
+
+        # models without features use sparse matrix
+        if features is None:
+            matrix_width = getattr(self, f"fit_{entity}s").count()
+            warm_ids = ids_list[ids_list < matrix_width]
+            sparse_features = csr_matrix(
+                (
+                    [1] * warm_ids.shape[0],
+                    (warm_ids, warm_ids),
+                ),
+                shape=(ids_list.max() + 1, matrix_width),
+            )
+        else:
+            sparse_features = self._feature_table_to_csr(ids, features)
+
+        biases, vectors = getattr(self.model, f"get_{entity}_representations")(sparse_features)
+
+        embed_list = list(
+            zip(
+                ids_list,
+                biases[ids_list].tolist(),
+                vectors[ids_list].tolist(),
+            )
+        )
+        lightfm_factors = State().session.createDataFrame(
+            embed_list,
+            schema=[
+                f"{entity}_idx",
+                f"{entity}_bias",
+                f"{entity}_factors",
+            ],
+        )
+        return lightfm_factors, self.model.no_components
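The _feature_table_to_csr docstring above describes the feature matrix as two blocks: a one-hot block for ids seen during fit, and a block of per-column min-max scaled numerical features, with every row then divided by its sum. A minimal NumPy/SciPy sketch of that normalization step on toy data (illustrative only, outside Spark; the array values are made up and are not part of the package):

import numpy as np
from scipy.sparse import csr_matrix, diags, hstack
from sklearn.preprocessing import MinMaxScaler

# Toy data: 3 warm entities (ids 0-2) with 2 numerical features each.
ids = np.array([0, 1, 2])
raw_feats = np.array([[10.0, 1.0], [20.0, 2.0], [30.0, 3.0]])

# Left block: one-hot encoding of the ids seen during fit.
one_hot = csr_matrix((np.ones(len(ids)), (ids, ids)), shape=(3, 3))

# Right block: MinMaxScaler applied per column.
scaled = csr_matrix(MinMaxScaler().fit_transform(raw_feats))

# Concatenate and divide each row by its sum, mirroring the end of _feature_table_to_csr.
concat = hstack([one_hot, scaled]).tocsr()
row_sum = np.asarray(concat.sum(axis=1)).ravel()
row_sum[row_sum != 0.0] = 1.0 / row_sum[row_sum != 0.0]
normalized = diags(row_sum, format="csr") @ concat

print(normalized.toarray())  # each non-empty row sums to 1

After this step every non-empty row sums to 1, which is the same row-sum division the wrapper performs before handing the matrix to LightFM.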
replay/experimental/models/mult_vae.py
@@ -0,0 +1,332 @@
+"""
+MultVAE implementation
+(Variational Autoencoders for Collaborative Filtering)
+"""
+
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as sf
+from scipy.sparse import csr_matrix
+from sklearn.model_selection import GroupShuffleSplit
+from torch import nn
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader, TensorDataset
+
+from replay.experimental.models.base_torch_rec import TorchRecommender
+from replay.utils import PandasDataFrame, SparkDataFrame
+
+
+class VAE(nn.Module):
+    """Base variational autoencoder"""
+
+    def __init__(
+        self,
+        item_count: int,
+        latent_dim: int,
+        hidden_dim: int = 600,
+        dropout: float = 0.3,
+    ):
+        """
+        :param item_count: number of items
+        :param latent_dim: latent dimension size
+        :param hidden_dim: hidden dimension size for encoder and decoder
+        :param dropout: dropout coefficient
+        """
+        super().__init__()
+
+        self.latent_dim = latent_dim
+        self.encoder_dims = [item_count, hidden_dim, latent_dim * 2]
+        self.decoder_dims = [latent_dim, hidden_dim, item_count]
+
+        self.encoder = nn.ModuleList(
+            [nn.Linear(d_in, d_out) for d_in, d_out in zip(self.encoder_dims[:-1], self.encoder_dims[1:])]
+        )
+        self.decoder = nn.ModuleList(
+            [nn.Linear(d_in, d_out) for d_in, d_out in zip(self.decoder_dims[:-1], self.decoder_dims[1:])]
+        )
+        self.dropout = nn.Dropout(dropout)
+        self.activation = torch.nn.ReLU()
+
+        for layer in self.encoder:
+            self.weight_init(layer)
+
+        for layer in self.decoder:
+            self.weight_init(layer)
+
+    def encode(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode"""
+        hidden = sf.normalize(batch, p=2, dim=1)
+        hidden = self.dropout(hidden)
+
+        for layer in self.encoder[:-1]:
+            hidden = layer(hidden)
+            hidden = self.activation(hidden)
+
+        hidden = self.encoder[-1](hidden)
+        mu_latent = hidden[:, : self.latent_dim]
+        logvar_latent = hidden[:, self.latent_dim :]
+        return mu_latent, logvar_latent
+
+    def reparameterize(self, mu_latent: torch.Tensor, logvar_latent: torch.Tensor) -> torch.Tensor:
+        """Reparametrization trick"""
+
+        if self.training:
+            std = torch.exp(0.5 * logvar_latent)
+            eps = torch.randn_like(std)
+            return eps * std + mu_latent
+        return mu_latent
+
+    def decode(self, z_latent: torch.Tensor) -> torch.Tensor:
+        """Decode"""
+        hidden = z_latent
+        for layer in self.decoder[:-1]:
+            hidden = layer(hidden)
+            hidden = self.activation(hidden)
+        return self.decoder[-1](hidden)
+
+    def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param batch: user batch
+        :return: output, expectation and logarithm of variance
+        """
+        mu_latent, logvar_latent = self.encode(batch)
+        z_latent = self.reparameterize(mu_latent, logvar_latent)
+        return self.decode(z_latent), mu_latent, logvar_latent
+
+    @staticmethod
+    def weight_init(layer: nn.Module):
+        """
+        Xavier initialization
+
+        :param layer: layer of a model
+        """
+        if isinstance(layer, nn.Linear):
+            nn.init.xavier_normal_(layer.weight.data)
+            layer.bias.data.normal_(0.0, 0.001)
+
+
+class MultVAE(TorchRecommender):
+    """`Variational Autoencoders for Collaborative Filtering
+    <https://arxiv.org/pdf/1802.05814.pdf>`_"""
+
+    num_workers: int = 0
+    batch_size_users: int = 5000
+    patience: int = 10
+    n_saved: int = 2
+    valid_split_size: float = 0.1
+    seed: int = 42
+    can_predict_cold_users = True
+    train_user_batch: csr_matrix
+    valid_user_batch: csr_matrix
+    _search_space = {
+        "learning_rate": {"type": "loguniform", "args": [0.0001, 0.5]},
+        "epochs": {"type": "int", "args": [100, 100]},
+        "latent_dim": {"type": "int", "args": [200, 200]},
+        "hidden_dim": {"type": "int", "args": [600, 600]},
+        "dropout": {"type": "uniform", "args": [0, 0.5]},
+        "anneal": {"type": "uniform", "args": [0.2, 1]},
+        "l2_reg": {"type": "loguniform", "args": [1e-9, 5]},
+        "factor": {"type": "uniform", "args": [0.2, 0.2]},
+        "patience": {"type": "int", "args": [3, 3]},
+    }
+
+    def __init__(
+        self,
+        learning_rate: float = 0.01,
+        epochs: int = 100,
+        latent_dim: int = 200,
+        hidden_dim: int = 600,
+        dropout: float = 0.3,
+        anneal: float = 0.1,
+        l2_reg: float = 0,
+        factor: float = 0.2,
+        patience: int = 3,
+    ):
+        """
+        :param learning_rate: learning rate
+        :param epochs: number of epochs to train model
+        :param latent_dim: latent dimension size for user vectors
+        :param hidden_dim: hidden dimension size for encoder and decoder
+        :param dropout: dropout coefficient
+        :param anneal: anneal coefficient [0,1]
+        :param l2_reg: l2 regularization term
+        :param factor: ReduceLROnPlateau reducing factor. new_lr = lr * factor
+        :param patience: number of non-improved epochs before reducing lr
+        """
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.epochs = epochs
+        self.latent_dim = latent_dim
+        self.hidden_dim = hidden_dim
+        self.dropout = dropout
+        self.anneal = anneal
+        self.l2_reg = l2_reg
+        self.factor = factor
+        self.patience = patience
+
+    @property
+    def _init_args(self):
+        return {
+            "learning_rate": self.learning_rate,
+            "epochs": self.epochs,
+            "latent_dim": self.latent_dim,
+            "hidden_dim": self.hidden_dim,
+            "dropout": self.dropout,
+            "anneal": self.anneal,
+            "l2_reg": self.l2_reg,
+            "factor": self.factor,
+            "patience": self.patience,
+        }
+
+    def _get_data_loader(
+        self, data: PandasDataFrame, shuffle: bool = True
+    ) -> tuple[csr_matrix, DataLoader, np.ndarray]:
+        """get data loader and matrix with data"""
+        users_count = data["user_idx"].value_counts().count()
+        user_idx = data["user_idx"].astype("category").cat
+        user_batch = csr_matrix(
+            (
+                np.ones(len(data["user_idx"])),
+                ([user_idx.codes.values, data["item_idx"].values]),
+            ),
+            shape=(users_count, self._item_dim),
+        )
+        data_loader = DataLoader(
+            TensorDataset(torch.arange(users_count).long()),
+            batch_size=self.batch_size_users,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+        )
+
+        return user_batch, data_loader, user_idx.categories.values
+
+    def _fit(
+        self,
+        log: SparkDataFrame,
+        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+    ) -> None:
+        self.logger.debug("Creating batch")
+        data = log.select("user_idx", "item_idx").toPandas()
+        splitter = GroupShuffleSplit(n_splits=1, test_size=self.valid_split_size, random_state=self.seed)
+        train_idx, valid_idx = next(splitter.split(data, groups=data["user_idx"]))
+        train_data, valid_data = data.iloc[train_idx], data.iloc[valid_idx]
+
+        self.train_user_batch, train_data_loader, _ = self._get_data_loader(train_data)
+        self.valid_user_batch, valid_data_loader, _ = self._get_data_loader(valid_data, False)
+
+        self.logger.debug("Training VAE")
+        self.model = VAE(
+            item_count=self._item_dim,
+            latent_dim=self.latent_dim,
+            hidden_dim=self.hidden_dim,
+            dropout=self.dropout,
+        ).to(self.device)
+        optimizer = Adam(
+            self.model.parameters(),
+            lr=self.learning_rate,
+            weight_decay=self.l2_reg / self.batch_size_users,
+        )
+        lr_scheduler = ReduceLROnPlateau(optimizer, factor=self.factor, patience=self.patience)
+
+        self.train(
+            train_data_loader,
+            valid_data_loader,
+            optimizer,
+            lr_scheduler,
+            self.epochs,
+            "multvae",
+        )
+
+    def _loss(self, y_pred, y_true, mu_latent, logvar_latent):
+        log_softmax_var = sf.log_softmax(y_pred, dim=1)
+        bce = -(log_softmax_var * y_true).sum(dim=1).mean()
+        kld = (
+            -0.5
+            * torch.sum(
+                1 + logvar_latent - mu_latent.pow(2) - logvar_latent.exp(),
+                dim=1,
+            ).mean()
+        )
+        return bce + self.anneal * kld
+
+    def _batch_pass(self, batch, model):
+        full_batch = self.train_user_batch if model.training else self.valid_user_batch
+        user_batch = torch.FloatTensor(full_batch[batch[0]].toarray()).to(self.device)
+        pred_user_batch, latent_mu, latent_logvar = self.model.forward(user_batch)
+        return {
+            "y_pred": pred_user_batch,
+            "y_true": user_batch,
+            "mu_latent": latent_mu,
+            "logvar_latent": latent_logvar,
+        }
+
+    @staticmethod
+    def _predict_pairs_inner(
+        model: nn.Module,
+        user_idx: int,
+        items_np_history: np.ndarray,
+        items_np_to_pred: np.ndarray,
+        item_count: int,
+        cnt: Optional[int] = None,
+    ) -> SparkDataFrame:
+        model.eval()
+        with torch.no_grad():
+            user_batch = torch.zeros((1, item_count))
+            user_batch[0, items_np_history] = 1
+            user_recs = sf.softmax(model(user_batch)[0][0].detach(), dim=0)
+            if cnt is not None:
+                best_item_idx = (torch.argsort(user_recs[items_np_to_pred], descending=True)[:cnt]).numpy()
+                items_np_to_pred = items_np_to_pred[best_item_idx]
+            return PandasDataFrame(
+                {
+                    "user_idx": np.array(items_np_to_pred.shape[0] * [user_idx]),
+                    "item_idx": items_np_to_pred,
+                    "relevance": user_recs[items_np_to_pred],
+                }
+            )
+
+    @staticmethod
+    def _predict_by_user(
+        pandas_df: PandasDataFrame,
+        model: nn.Module,
+        items_np: np.ndarray,
+        k: int,
+        item_count: int,
+    ) -> PandasDataFrame:
+        return MultVAE._predict_pairs_inner(
+            model=model,
+            user_idx=pandas_df["user_idx"][0],
+            items_np_history=pandas_df["item_idx"].values,
+            items_np_to_pred=items_np,
+            item_count=item_count,
+            cnt=min(len(pandas_df) + k, len(items_np)),
+        )
+
+    @staticmethod
+    def _predict_by_user_pairs(
+        pandas_df: PandasDataFrame,
+        model: nn.Module,
+        item_count: int,
+    ) -> PandasDataFrame:
+        return MultVAE._predict_pairs_inner(
+            model=model,
+            user_idx=pandas_df["user_idx"][0],
+            items_np_history=np.array(pandas_df["item_idx_history"][0]),
+            items_np_to_pred=np.array(pandas_df["item_idx_to_pred"][0]),
+            item_count=item_count,
+            cnt=None,
+        )
+
+    def _load_model(self, path: str):
+        self.model = VAE(
+            item_count=self._item_dim,
+            latent_dim=self.latent_dim,
+            hidden_dim=self.hidden_dim,
+            dropout=self.dropout,
+        ).to(self.device)
+        self.model.load_state_dict(torch.load(path))
+        self.model.eval()
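For reference, MultVAE._loss above is the annealed objective from the Mult-VAE paper: a multinomial log-likelihood term plus anneal times the KL divergence between the latent posterior N(mu, exp(logvar)) and a standard normal. A self-contained sketch of the same computation on toy tensors (illustrative only; the shapes and values are made up and this is not package code):

import torch
import torch.nn.functional as F

# Toy batch: 2 users, 4 items; y_true is the binary interaction matrix.
y_true = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
y_pred = torch.randn(2, 4)   # decoder logits
mu = torch.randn(2, 3)       # latent mean
logvar = torch.randn(2, 3)   # latent log-variance
anneal = 0.1

# Multinomial log-likelihood term, as in MultVAE._loss.
log_softmax_var = F.log_softmax(y_pred, dim=1)
nll = -(log_softmax_var * y_true).sum(dim=1).mean()

# KL divergence between N(mu, exp(logvar)) and N(0, I), averaged over the batch.
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

loss = nll + anneal * kld
print(loss.item())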