replay-rec 0.20.2__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. replay/__init__.py +1 -1
  2. replay/data/nn/sequential_dataset.py +8 -2
  3. replay/experimental/__init__.py +0 -0
  4. replay/experimental/metrics/__init__.py +62 -0
  5. replay/experimental/metrics/base_metric.py +603 -0
  6. replay/experimental/metrics/coverage.py +97 -0
  7. replay/experimental/metrics/experiment.py +175 -0
  8. replay/experimental/metrics/hitrate.py +26 -0
  9. replay/experimental/metrics/map.py +30 -0
  10. replay/experimental/metrics/mrr.py +18 -0
  11. replay/experimental/metrics/ncis_precision.py +31 -0
  12. replay/experimental/metrics/ndcg.py +49 -0
  13. replay/experimental/metrics/precision.py +22 -0
  14. replay/experimental/metrics/recall.py +25 -0
  15. replay/experimental/metrics/rocauc.py +49 -0
  16. replay/experimental/metrics/surprisal.py +90 -0
  17. replay/experimental/metrics/unexpectedness.py +76 -0
  18. replay/experimental/models/__init__.py +50 -0
  19. replay/experimental/models/admm_slim.py +257 -0
  20. replay/experimental/models/base_neighbour_rec.py +200 -0
  21. replay/experimental/models/base_rec.py +1386 -0
  22. replay/experimental/models/base_torch_rec.py +234 -0
  23. replay/experimental/models/cql.py +454 -0
  24. replay/experimental/models/ddpg.py +932 -0
  25. replay/experimental/models/dt4rec/__init__.py +0 -0
  26. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  27. replay/experimental/models/dt4rec/gpt1.py +401 -0
  28. replay/experimental/models/dt4rec/trainer.py +127 -0
  29. replay/experimental/models/dt4rec/utils.py +264 -0
  30. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  31. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  32. replay/experimental/models/hierarchical_recommender.py +331 -0
  33. replay/experimental/models/implicit_wrap.py +131 -0
  34. replay/experimental/models/lightfm_wrap.py +303 -0
  35. replay/experimental/models/mult_vae.py +332 -0
  36. replay/experimental/models/neural_ts.py +986 -0
  37. replay/experimental/models/neuromf.py +406 -0
  38. replay/experimental/models/scala_als.py +293 -0
  39. replay/experimental/models/u_lin_ucb.py +115 -0
  40. replay/experimental/nn/data/__init__.py +1 -0
  41. replay/experimental/nn/data/schema_builder.py +102 -0
  42. replay/experimental/preprocessing/__init__.py +3 -0
  43. replay/experimental/preprocessing/data_preparator.py +839 -0
  44. replay/experimental/preprocessing/padder.py +229 -0
  45. replay/experimental/preprocessing/sequence_generator.py +208 -0
  46. replay/experimental/scenarios/__init__.py +1 -0
  47. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  48. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  49. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  50. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  51. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  52. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  53. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  54. replay/experimental/utils/__init__.py +0 -0
  55. replay/experimental/utils/logger.py +24 -0
  56. replay/experimental/utils/model_handler.py +186 -0
  57. replay/experimental/utils/session_handler.py +44 -0
  58. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
  59. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +62 -7
  60. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
  61. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
  62. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,986 @@
+ import os
+ from typing import Optional, Union
+
+ import joblib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import torch
+ from IPython.display import clear_output
+ from pyspark.sql import DataFrame
+ from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
+ from torch import Tensor, nn
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
+ from tqdm import tqdm
+
+ from replay.experimental.models.base_rec import HybridRecommender
+ from replay.splitters import TimeSplitter
+ from replay.utils.spark_utils import convert2spark
+
+ pd.options.mode.chained_assignment = None
+
+
+ def cartesian_product(left, right):
+     """
+     This function computes the Cartesian product of two pandas DataFrames.
+     """
+     return left.assign(key=1).merge(right.assign(key=1), on="key").drop(columns=["key"])
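For intuition, a minimal sketch of what cartesian_product returns (the toy frames here are illustrative, not part of the package):

users = pd.DataFrame({"user_idx": [0, 1]})
items = pd.DataFrame({"item_idx": [10, 20, 30]})
pairs = cartesian_product(users, items)  # 2 x 3 = 6 rows, one per (user_idx, item_idx) pair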
+
+
+ def num_tries_gt_zero(scores, batch_size, max_trials, max_num, device):
+     """
+     scores: [batch_size x N] float scores
+     returns: [batch_size x 1] the lowest index per row where scores were first greater than 0, plus 1
+     """
+     tmp = scores.gt(0).nonzero().t()
+     # We offset these values by 1 to look for unset values (zeros) later
+     values = tmp[1] + 1
+     # Sparse tensors can't be moved with .to() or .cuda() if you want to send in cuda variables first
+     if device.type == "cuda":
+         tau = torch.cuda.sparse.LongTensor(tmp, values, torch.Size((batch_size, max_trials + 1))).to_dense()
+     else:
+         tau = torch.sparse.LongTensor(tmp, values, torch.Size((batch_size, max_trials + 1))).to_dense()
+     tau[(tau == 0)] += max_num  # set all unused indices to the max possible number so they are not picked by min()
+
+     tries = torch.min(tau, dim=1)[0]
+     return tries
+
+
+ def w_log_loss(output, target, device):
+     """
+     This function computes the class-weighted logistic loss.
+     """
+     output = torch.nn.functional.sigmoid(output)
+     output = torch.clamp(output, min=1e-7, max=1 - 1e-7)
+     count_1 = target.sum().item()
+     count_0 = target.shape[0] - count_1
+     class_count = np.array([count_0, count_1])
+     if count_1 == 0 or count_0 == 0:  # noqa: SIM108
+         weight = np.array([1.0, 1.0])
+     else:
+         weight = np.max(class_count) / class_count
+     weight = Tensor(weight).to(device)
+     loss = weight[1] * target * torch.log(output) + weight[0] * (1 - target) * torch.log(1 - output)
+     return -loss.mean()
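A quick sanity check of w_log_loss on toy tensors (values are illustrative): with an imbalanced target, the minority class receives the larger weight.

device = torch.device("cpu")
logits = torch.randn(6)                                # raw model outputs
target = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 1.0])  # 2 positives, 4 negatives
loss = w_log_loss(logits, target, device)              # positives weighted max(4,2)/2 = 2x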
+
+
+ def warp_loss(positive_predictions, negative_predictions, num_labels, device):
+     """
+     positive_predictions: [batch_size x 1] floats between -1 to 1
+     negative_predictions: [batch_size x N] floats between -1 to 1
+     num_labels: int total number of labels in the dataset (not just the subset used for the batch)
+     device: torch.device
+     """
+     batch_size, max_trials = negative_predictions.size(0), negative_predictions.size(1)
+
+     offsets, ones, max_num = (
+         torch.arange(0, batch_size, 1).long().to(device) * (max_trials + 1),
+         torch.ones(batch_size, 1).float().to(device),
+         batch_size * (max_trials + 1),
+     )
+
+     sample_scores = 1 + negative_predictions - positive_predictions
+     # Add a column of ones so we know when we have used all our attempts.
+     # This is used for indexing and computing should_count_loss if no real value is above 0
+     sample_scores, negative_predictions = (
+         torch.cat([sample_scores, ones], dim=1),
+         torch.cat([negative_predictions, ones], dim=1),
+     )
+
+     tries = num_tries_gt_zero(sample_scores, batch_size, max_trials, max_num, device)
+     attempts, trial_offset = tries.float(), (tries - 1) + offsets
+     # Don't count the loss if we used the max number of attempts
+     loss_weights = torch.log(torch.floor((num_labels - 1) / attempts))
+     should_count_loss = (attempts <= max_trials).float()
+     losses = (
+         loss_weights
+         * ((1 - positive_predictions.view(-1)) + negative_predictions.view(-1)[trial_offset])
+         * should_count_loss
+     )
+
+     return losses.sum()
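A minimal smoke test of warp_loss under the shapes assumed in its docstring (scores in (-1, 1); num_labels stands in for the catalogue size and is invented here):

device = torch.device("cpu")
pos = torch.rand(4, 1) * 2 - 1   # one positive score per user
neg = torch.rand(4, 8) * 2 - 1   # eight sampled negative scores per user
loss = warp_loss(pos, neg, num_labels=100, device=device)  # scalar tensor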
+
+
+ class SamplerWithReset(SequentialSampler):
+     """
+     Sequential sampler for the train dataloader that resets its dataset
+     (resampling the negative examples) at the start of every epoch.
+     """
+
+     def __iter__(self):
+         self.data_source.reset()
+         return super().__iter__()
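The pattern generalizes: any Dataset that exposes a reset() hook gets fresh samples each epoch, because DataLoader re-invokes the sampler's __iter__. A minimal sketch (ToyDataset is hypothetical):

class ToyDataset(Dataset):
    def __init__(self):
        self.reset()

    def reset(self):
        self.data = torch.randn(8)  # pretend these are freshly sampled negatives

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return 8


dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=4, sampler=SamplerWithReset(dataset))
# every `for batch in loader:` pass starts by re-drawing dataset.data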
+
+
+ class UserDatasetWithReset(Dataset):
+     """
+     Dataset that holds the data of a single user together with the
+     column names for continuous features, categorical features and the
+     features of the Wide model, as well as the name of the target column.
+     The class also supports sampling of negative examples.
+     """
+
+     def __init__(
+         self,
+         idx,
+         log_train,
+         user_features,
+         item_features,
+         list_items,
+         union_cols,
+         cnt_neg_samples,
+         device,
+         target: Optional[str] = None,
+     ):
+         if cnt_neg_samples is not None:
+             self.cnt_neg_samples = cnt_neg_samples
+             self.user_features = user_features
+             self.item_features = item_features
+             item_idx_user = log_train["item_idx"].values.tolist()
+             self.item_idx_not_user = list(set(list_items).difference(set(item_idx_user)))
+         else:
+             self.cnt_neg_samples = cnt_neg_samples
+             self.user_features = None
+             self.item_features = None
+             self.item_idx_not_user = None
+         self.device = device
+         self.union_cols = union_cols
+         dataframe = log_train.merge(user_features, on="user_idx", how="inner")
+         self.dataframe = dataframe.merge(item_features, on="item_idx", how="inner")
+         self.user_idx = idx
+         self.data_sample = None
+         self.wide_part = Tensor(self.dataframe[self.union_cols["wide_cols"]].to_numpy().astype("float32")).to(
+             self.device
+         )
+         self.continuous_part = Tensor(
+             self.dataframe[self.union_cols["continuous_cols"]].to_numpy().astype("float32")
+         ).to(self.device)
+         self.cat_part = Tensor(self.dataframe[self.union_cols["cat_embed_cols"]].to_numpy().astype("float32")).to(
+             self.device
+         )
+         self.users = Tensor(self.dataframe[["user_idx"]].to_numpy().astype("int")).to(torch.long).to(self.device)
+         self.items = Tensor(self.dataframe[["item_idx"]].to_numpy().astype("int")).to(torch.long).to(self.device)
+         if target is not None:
+             self.target = Tensor(dataframe[target].to_numpy().astype("int")).to(self.device)
+         else:
+             self.target = target
+         self.target_column = target
+
+     def get_parts(self, data_sample):
+         """
+         Dataset method that selects the user index, item indexes, categorical
+         features, continuous features, Wide-model features and target values.
+         """
+         self.wide_part = Tensor(data_sample[self.union_cols["wide_cols"]].to_numpy().astype("float32")).to(self.device)
+         self.continuous_part = Tensor(data_sample[self.union_cols["continuous_cols"]].to_numpy().astype("float32")).to(
+             self.device
+         )
+         self.cat_part = Tensor(data_sample[self.union_cols["cat_embed_cols"]].to_numpy().astype("float32")).to(
+             self.device
+         )
+         self.users = Tensor(data_sample[["user_idx"]].to_numpy().astype("int")).to(torch.long).to(self.device)
+         self.items = Tensor(data_sample[["item_idx"]].to_numpy().astype("int")).to(torch.long).to(self.device)
+         if self.target_column is not None:
+             self.target = Tensor(data_sample[self.target_column].to_numpy().astype("int")).to(self.device)
+         else:
+             self.target = self.target_column
+
+     def __getitem__(self, idx):
+         target = -1
+         if self.target is not None:
+             target = self.target[idx]
+         return (
+             self.wide_part[idx],
+             self.continuous_part[idx],
+             self.cat_part[idx],
+             self.users[idx],
+             self.items[idx],
+             target,
+         )
+
+     def __len__(self):
+         if self.data_sample is not None:
+             return self.data_sample.shape[0]
+         else:
+             return self.dataframe.shape[0]
+
+     def get_size_features(self):
+         """
+         Dataset method that returns the sizes of the feature blocks after encoding/scaling.
+         """
+         return self.wide_part.shape[1], self.continuous_part.shape[1], self.cat_part.shape[1]
+
+     def reset(self):
+         """
+         Dataset method that samples a new set of negative examples.
+         """
+         n_samples = min(len(self.item_idx_not_user), self.cnt_neg_samples)
+         if n_samples > 0:
+             sample_item = np.random.choice(self.item_idx_not_user, n_samples, replace=False)
+             sample_item_feat = self.item_features.loc[self.item_features["item_idx"].isin(sample_item)]
+             sample_item_feat = sample_item_feat.set_axis(range(sample_item_feat.shape[0]), axis="index")
+             df_sample = cartesian_product(
+                 self.user_features.loc[self.user_features["user_idx"] == self.user_idx], sample_item_feat
+             )
+             df_sample[self.target_column] = 0
+             self.data_sample = pd.concat([self.dataframe, df_sample], axis=0, ignore_index=True)
+             self.get_parts(self.data_sample)
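A hedged sketch of the expected per-user input schema (all frames and column names below are illustrative): two positive interactions, a three-item catalogue, and one sampled negative.

log = pd.DataFrame({"user_idx": [0, 0], "item_idx": [10, 20], "relevance": [1, 1]})
user_feats = pd.DataFrame({"user_idx": [0], "age": [0.5]})
item_feats = pd.DataFrame({"item_idx": [10, 20, 30], "price": [0.1, 0.7, 0.3]})
union_cols = {"wide_cols": ["age"], "continuous_cols": ["price"], "cat_embed_cols": []}
ds = UserDatasetWithReset(
    idx=0, log_train=log, user_features=user_feats, item_features=item_feats,
    list_items=[10, 20, 30], union_cols=union_cols, cnt_neg_samples=1,
    device=torch.device("cpu"), target="relevance",
)
ds.reset()  # appends one negative row (item 30) with relevance 0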
227
+
228
+
229
+ class Wide(nn.Module):
230
+ """
231
+ Wide model based on https://arxiv.org/abs/1606.07792
232
+ """
233
+
234
+ def __init__(self, input_dim: int, out_dim: int = 1):
235
+ super().__init__()
236
+
237
+ self.linear = nn.Sequential(nn.Linear(input_dim, out_dim), nn.ReLU(), nn.BatchNorm1d(out_dim))
238
+ self.out_dim = out_dim
239
+
240
+ def forward(self, input_data):
241
+ """
242
+ :param input_data: wide features
243
+ :return: torch tensor with shape batch_size*out_dim
244
+ """
245
+ output = self.linear(input_data)
246
+ return output
247
+
248
+
249
+ class Deep(nn.Module):
250
+ """
251
+ Deep model based on https://arxiv.org/abs/1606.07792
252
+ """
253
+
254
+ def __init__(self, input_dim: int, out_dim: int, hidden_layers: list[int], deep_dropout: float):
255
+ super().__init__()
256
+ model = []
257
+ last_size = input_dim
258
+ for cur_size in hidden_layers:
259
+ model += [nn.Linear(last_size, cur_size)]
260
+ model += [nn.ReLU()]
261
+ model += [nn.BatchNorm1d(cur_size)]
262
+ model += [nn.Dropout(deep_dropout)]
263
+ last_size = cur_size
264
+ model += [nn.Linear(last_size, out_dim)]
265
+ model += [nn.ReLU()]
266
+ model += [nn.BatchNorm1d(out_dim)]
267
+ model += [nn.Dropout(deep_dropout)]
268
+ self.deep_model = nn.Sequential(*model)
269
+
270
+ def forward(self, input_data):
271
+ """
272
+ :param input_data: deep features
273
+ :return: torch tensor with shape batch_size*out_dim
274
+ """
275
+ output = self.deep_model(input_data)
276
+ return output
277
+
278
+
279
+ class EmbedModel(nn.Module):
280
+ """
281
+ Model for getting embeddings for user and item indexes.
282
+ """
283
+
284
+ def __init__(self, cnt_users: int, cnt_items: int, user_embed: int, item_embed: int, crossed_embed: int):
285
+ super().__init__()
286
+ self.user_embed = nn.Embedding(num_embeddings=cnt_users, embedding_dim=user_embed)
287
+ self.item_embed = nn.Embedding(num_embeddings=cnt_items, embedding_dim=item_embed)
288
+ self.user_crossed_embed = nn.Embedding(num_embeddings=cnt_users, embedding_dim=crossed_embed)
289
+ self.item_crossed_embed = nn.Embedding(num_embeddings=cnt_items, embedding_dim=crossed_embed)
290
+
291
+ def forward(self, users, items):
292
+ """
293
+ :param users: user indexes
294
+ :param items: item indexes
295
+ :return: torch tensors: embedings for users, embedings for items,
296
+ embedings for users for wide model,
297
+ embedings for items for wide model,
298
+ embedings for pairs (users, items) for wide model
299
+ """
300
+ users_to_embed = self.user_embed(users).squeeze()
301
+ items_to_embed = self.item_embed(items).squeeze()
302
+ cross_users = self.user_crossed_embed(users).squeeze()
303
+ cross_items = self.item_crossed_embed(items).squeeze()
304
+ cross = (cross_users * cross_items).sum(dim=-1).unsqueeze(-1)
305
+ return users_to_embed, items_to_embed, cross_users, cross_items, cross
306
+
307
+
308
+ class WideDeep(nn.Module):
309
+ """
310
+ Wide&Deep model based on https://arxiv.org/abs/1606.07792
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ dim_head: int,
316
+ deep_out_dim: int,
317
+ hidden_layers: list[int],
318
+ size_wide_features: int,
319
+ size_continuous_features: int,
320
+ size_cat_features: int,
321
+ wide_out_dim: int,
322
+ head_dropout: float,
323
+ deep_dropout: float,
324
+ cnt_users: int,
325
+ cnt_items: int,
326
+ user_embed: int,
327
+ item_embed: int,
328
+ crossed_embed: int,
329
+ ):
330
+ super().__init__()
331
+ self.embed_model = EmbedModel(cnt_users, cnt_items, user_embed, item_embed, crossed_embed)
332
+ self.wide = Wide(size_wide_features + crossed_embed * 2 + 1, wide_out_dim)
333
+ self.deep = Deep(
334
+ size_cat_features + size_continuous_features + user_embed + item_embed,
335
+ deep_out_dim,
336
+ hidden_layers,
337
+ deep_dropout,
338
+ )
339
+ self.head_model = nn.Sequential(nn.Linear(wide_out_dim + deep_out_dim, dim_head), nn.ReLU())
340
+ self.last_layer = nn.Sequential(nn.Linear(dim_head, 1))
341
+ self.head_dropout = head_dropout
342
+
343
+ def forward_for_predict(self, wide_part, continuous_part, cat_part, users, items):
344
+ """
345
+ Forward method without last layer and dropout that is used for prediction.
346
+ """
347
+ users_to_embed, items_to_embed, cross_users, cross_items, cross = self.embed_model(users, items)
348
+ input_deep = torch.cat((cat_part, continuous_part, users_to_embed, items_to_embed), dim=-1).squeeze()
349
+ out_deep = self.deep(input_deep)
350
+ wide_part = torch.cat((wide_part, cross_users, cross_items, cross), dim=-1)
351
+ out_wide = self.wide(wide_part)
352
+ input_data = torch.cat((out_wide, out_deep), dim=-1)
353
+ output = self.head_model(input_data)
354
+ return output
355
+
356
+ def forward_dropout(self, input_data):
357
+ """
358
+ Forward method for multiple prediction with active dropout
359
+ :param input_data: output of forward_for_predict
360
+ :return: torch tensor after dropout and last linear layer
361
+ """
362
+ output = nn.functional.dropout(input_data, p=self.head_dropout, training=True)
363
+ output = self.last_layer(output)
364
+ return output
365
+
366
+ def forward_for_embeddings(
367
+ self, wide_part, continuous_part, cat_part, users_to_embed, items_to_embed, cross_users, cross_items, cross
368
+ ):
369
+ """
370
+ Forward method after getting emdeddings for users and items.
371
+ """
372
+ input_deep = torch.cat((cat_part, continuous_part, users_to_embed, items_to_embed), dim=-1).squeeze()
373
+ out_deep = self.deep(input_deep)
374
+ wide_part = torch.cat((wide_part, cross_users, cross_items, cross), dim=-1)
375
+ out_wide = self.wide(wide_part)
376
+ input_data = torch.cat((out_wide, out_deep), dim=-1)
377
+ output = self.head_model(input_data)
378
+ output = nn.functional.dropout(output, p=self.head_dropout, training=True)
379
+ output = self.last_layer(output)
380
+ return output
381
+
382
+ def forward(self, wide_part, continuous_part, cat_part, users, items):
383
+ """
384
+ :param wide_part: features for wide model
385
+ :param continuous_part: continuous features
386
+ :param cat_part: torch categorical features
387
+ :param users: user indexes
388
+ :param items: item indexes
389
+ :return: relevances for pair (users, items)
390
+
391
+ """
392
+ users_to_embed, items_to_embed, cross_users, cross_items, cross = self.embed_model(users, items)
393
+ output = self.forward_for_embeddings(
394
+ wide_part, continuous_part, cat_part, users_to_embed, items_to_embed, cross_users, cross_items, cross
395
+ )
396
+ return output
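A shape-level sanity check of the WideDeep forward pass (all sizes are arbitrary; note the Wide input is size_wide_features + 2*crossed_embed + 1 because the crossed embeddings and their dot product are appended):

model = WideDeep(
    dim_head=8, deep_out_dim=8, hidden_layers=[16],
    size_wide_features=5, size_continuous_features=3, size_cat_features=4,
    wide_out_dim=1, head_dropout=0.1, deep_dropout=0.1,
    cnt_users=10, cnt_items=20, user_embed=6, item_embed=6, crossed_embed=4,
)
wide = torch.rand(4, 5)
cont = torch.rand(4, 3)
cat = torch.rand(4, 4)
users = torch.randint(0, 10, (4, 1))
items = torch.randint(0, 20, (4, 1))
scores = model(wide, cont, cat, users, items)  # shape (4, 1)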
+
+
+ class NeuralTS(HybridRecommender):
+     """
+     `Neural Thompson sampling recommender
+     <https://dl.acm.org/doi/pdf/10.1145/3383313.3412214>`_ based on the `Wide&Deep model
+     <https://arxiv.org/abs/1606.07792>`_.
+
+     :param user_cols: user_cols = {'continuous_cols': List[str], 'cat_embed_cols': List[str], 'wide_cols': List[str]},
+         where each List[str] holds column names from the user_features dataframe passed to the fit method,
+         or is an empty list
+     :param item_cols: item_cols = {'continuous_cols': List[str], 'cat_embed_cols': List[str], 'wide_cols': List[str]},
+         where each List[str] holds column names from the item_features dataframe passed to the fit method,
+         or is an empty list
+     :param embedding_sizes: list of length three in which
+         embedding_sizes[0] = embedding size for users,
+         embedding_sizes[1] = embedding size for items,
+         embedding_sizes[2] = embedding size for pairs (users, items)
+     :param hidden_layers: list of hidden layer sizes for the Deep model
+     :param wide_out_dim: output size of the Wide model
+     :param deep_out_dim: output size of the Deep model
+     :param head_dropout: probability of an element being zeroed in the WideDeep model head
+     :param deep_dropout: probability of an element being zeroed in the Deep model
+     :param dim_head: output size of the WideDeep model head
+     :param n_epochs: number of training epochs
+     :param opt_lr: learning rate for the AdamW optimizer
+     :param lr_min: minimum learning rate for the CosineAnnealingLR scheduler
+     :param use_gpu: if true, the model is trained on the GPU
+     :param use_warp_loss: if true, WARP loss is used; otherwise the weighted logistic loss
+     :param cnt_neg_samples: number of additional negative examples per user
+     :param cnt_samples_for_predict: number of sampled predictions per user,
+         used to estimate the mean and variance of the relevance
+     :param exploration_coef: exploration coefficient
+     :param plot_dir: file name the training plots are saved to; if None, the plots are not saved
+     :param cnt_users: number of users, used to initialize the Wide&Deep model
+     :param cnt_items: number of items, used to initialize the Wide&Deep model
+     """
+
+     def __init__(
+         self,
+         user_cols: dict[str, list[str]] = {"continuous_cols": [], "cat_embed_cols": [], "wide_cols": []},
+         item_cols: dict[str, list[str]] = {"continuous_cols": [], "cat_embed_cols": [], "wide_cols": []},
+         embedding_sizes: list[int] = [32, 32, 64],
+         hidden_layers: list[int] = [32, 20],
+         wide_out_dim: int = 1,
+         deep_out_dim: int = 20,
+         head_dropout: float = 0.8,
+         deep_dropout: float = 0.4,
+         dim_head: int = 20,
+         n_epochs: int = 2,
+         opt_lr: float = 3e-4,
+         lr_min: float = 1e-5,
+         use_gpu: bool = False,
+         use_warp_loss: bool = True,
+         cnt_neg_samples: int = 100,
+         cnt_samples_for_predict: int = 10,
+         exploration_coef: float = 1.0,
+         cnt_users: Optional[int] = None,
+         cnt_items: Optional[int] = None,
+         plot_dir: Optional[str] = None,
+     ):
+         self.user_cols = user_cols
+         self.item_cols = item_cols
+         self.embedding_sizes = embedding_sizes
+         self.hidden_layers = hidden_layers
+         self.wide_out_dim = wide_out_dim
+         self.deep_out_dim = deep_out_dim
+         self.head_dropout = head_dropout
+         self.deep_dropout = deep_dropout
+         self.dim_head = dim_head
+         self.n_epochs = n_epochs
+         self.opt_lr = opt_lr
+         self.lr_min = lr_min
+         self.device = torch.device("cpu")
+         if use_gpu:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.use_warp_loss = use_warp_loss
+         self.cnt_neg_samples = cnt_neg_samples
+         self.cnt_samples_for_predict = cnt_samples_for_predict
+         self.exploration_coef = exploration_coef
+         self.cnt_users = cnt_users
+         self.cnt_items = cnt_items
+         self.plot_dir = plot_dir
+
+         self.size_wide_features = None
+         self.size_continuous_features = None
+         self.size_cat_features = None
+         self.scaler_user = None
+         self.encoder_intersept_user = None
+         self.encoder_diff_user = None
+         self.scaler_item = None
+         self.encoder_intersept_item = None
+         self.encoder_diff_item = None
+         self.union_cols = None
+         self.num_of_train_labels = None
+         self.dict_true_items_val = None
+         self.lr_scheduler = None
+         self.model = None
+         self.criterion = None
+         self.optimizer = None
+
+     def preprocess_features_fit(self, train, item_features, user_features):
+         """
+         This function fits all encoders and scalers for the features.
+         """
+         train_users = user_features.loc[user_features["user_idx"].isin(train["user_idx"].values.tolist())]
+         wide_cols_cat = list(set(self.user_cols["cat_embed_cols"]) & set(self.user_cols["wide_cols"]))
+         cat_embed_cols_not_wide = list(set(self.user_cols["cat_embed_cols"]).difference(set(wide_cols_cat)))
+         if len(self.user_cols["continuous_cols"]) != 0:
+             self.scaler_user = MinMaxScaler()
+             self.scaler_user.fit(train_users[self.user_cols["continuous_cols"]])
+         else:
+             self.scaler_user = None
+         if len(wide_cols_cat) != 0:
+             self.encoder_intersept_user = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+             self.encoder_intersept_user.fit(train_users[wide_cols_cat])
+         else:
+             self.encoder_intersept_user = None
+         if len(cat_embed_cols_not_wide) != 0:
+             self.encoder_diff_user = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+             self.encoder_diff_user.fit(train_users[cat_embed_cols_not_wide])
+         else:
+             self.encoder_diff_user = None
+         train_items = item_features.loc[item_features["item_idx"].isin(train["item_idx"].values.tolist())]
+         wide_cols_cat = list(set(self.item_cols["cat_embed_cols"]) & set(self.item_cols["wide_cols"]))
+         cat_embed_cols_not_wide = list(set(self.item_cols["cat_embed_cols"]).difference(set(wide_cols_cat)))
+         if len(self.item_cols["continuous_cols"]) != 0:
+             self.scaler_item = MinMaxScaler()
+             self.scaler_item.fit(train_items[self.item_cols["continuous_cols"]])
+         else:
+             self.scaler_item = None
+         if len(wide_cols_cat) != 0:
+             self.encoder_intersept_item = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+             self.encoder_intersept_item.fit(train_items[wide_cols_cat])
+         else:
+             self.encoder_intersept_item = None
+         if len(cat_embed_cols_not_wide) != 0:
+             self.encoder_diff_item = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
+             self.encoder_diff_item.fit(train_items[cat_embed_cols_not_wide])
+         else:
+             self.encoder_diff_item = None
+
+     def preprocess_features_transform(self, item_features, user_features):
+         """
+         This function performs the transformation for all features.
+         """
+         self.union_cols = {"continuous_cols": [], "cat_embed_cols": [], "wide_cols": []}
+         wide_cols_cat = list(set(self.user_cols["cat_embed_cols"]) & set(self.user_cols["wide_cols"]))
+         cat_embed_cols_not_wide = list(set(self.user_cols["cat_embed_cols"]).difference(set(wide_cols_cat)))
+         if len(self.user_cols["continuous_cols"]) != 0:
+             users_continuous = pd.DataFrame(
+                 self.scaler_user.transform(user_features[self.user_cols["continuous_cols"]]),
+                 columns=self.user_cols["continuous_cols"],
+             )
+             self.union_cols["continuous_cols"] += self.user_cols["continuous_cols"]
+         else:
+             users_continuous = user_features[[]]
+         if len(wide_cols_cat) != 0:
+             users_wide_cat = pd.DataFrame(
+                 self.encoder_intersept_user.transform(user_features[wide_cols_cat]),
+                 columns=list(self.encoder_intersept_user.get_feature_names_out(wide_cols_cat)),
+             )
+             self.union_cols["cat_embed_cols"] += list(self.encoder_intersept_user.get_feature_names_out(wide_cols_cat))
+             self.union_cols["wide_cols"] += list(
+                 set(self.user_cols["wide_cols"]).difference(set(self.user_cols["cat_embed_cols"]))
+             ) + list(self.encoder_intersept_user.get_feature_names_out(wide_cols_cat))
+         else:
+             users_wide_cat = user_features[[]]
+         if len(cat_embed_cols_not_wide) != 0:
+             users_cat = pd.DataFrame(
+                 self.encoder_diff_user.transform(user_features[cat_embed_cols_not_wide]),
+                 columns=list(self.encoder_diff_user.get_feature_names_out(cat_embed_cols_not_wide)),
+             )
+             self.union_cols["cat_embed_cols"] += list(
+                 self.encoder_diff_user.get_feature_names_out(cat_embed_cols_not_wide)
+             )
+         else:
+             users_cat = user_features[[]]
+
+         transform_user_features = pd.concat(
+             [user_features[["user_idx"]], users_continuous, users_wide_cat, users_cat], axis=1
+         )
+
+         wide_cols_cat = list(set(self.item_cols["cat_embed_cols"]) & set(self.item_cols["wide_cols"]))
+         cat_embed_cols_not_wide = list(set(self.item_cols["cat_embed_cols"]).difference(set(wide_cols_cat)))
+         if len(self.item_cols["continuous_cols"]) != 0:
+             items_continuous = pd.DataFrame(
+                 self.scaler_item.transform(item_features[self.item_cols["continuous_cols"]]),
+                 columns=self.item_cols["continuous_cols"],
+             )
+             self.union_cols["continuous_cols"] += self.item_cols["continuous_cols"]
+         else:
+             items_continuous = item_features[[]]
+         if len(wide_cols_cat) != 0:
+             items_wide_cat = pd.DataFrame(
+                 self.encoder_intersept_item.transform(item_features[wide_cols_cat]),
+                 columns=list(self.encoder_intersept_item.get_feature_names_out(wide_cols_cat)),
+             )
+             self.union_cols["cat_embed_cols"] += list(self.encoder_intersept_item.get_feature_names_out(wide_cols_cat))
+             self.union_cols["wide_cols"] += list(
+                 set(self.item_cols["wide_cols"]).difference(set(self.item_cols["cat_embed_cols"]))
+             ) + list(self.encoder_intersept_item.get_feature_names_out(wide_cols_cat))
+         else:
+             items_wide_cat = item_features[[]]
+         if len(cat_embed_cols_not_wide) != 0:
+             items_cat = pd.DataFrame(
+                 self.encoder_diff_item.transform(item_features[cat_embed_cols_not_wide]),
+                 columns=list(self.encoder_diff_item.get_feature_names_out(cat_embed_cols_not_wide)),
+             )
+             self.union_cols["cat_embed_cols"] += list(
+                 self.encoder_diff_item.get_feature_names_out(cat_embed_cols_not_wide)
+             )
+         else:
+             items_cat = item_features[[]]
+
+         transform_item_features = pd.concat(
+             [item_features[["item_idx"]], items_continuous, items_wide_cat, items_cat], axis=1
+         )
+         return transform_user_features, transform_item_features
+
+     def _data_loader(
+         self, idx, log_train, transform_user_features, transform_item_features, list_items, train=False
+     ) -> Union[tuple[UserDatasetWithReset, DataLoader], DataLoader]:
+         if train:
+             train_dataset = UserDatasetWithReset(
+                 idx=idx,
+                 log_train=log_train,
+                 user_features=transform_user_features,
+                 item_features=transform_item_features,
+                 list_items=list_items,
+                 union_cols=self.union_cols,
+                 cnt_neg_samples=self.cnt_neg_samples,
+                 device=self.device,
+                 target="relevance",
+             )
+             sampler = SamplerWithReset(train_dataset)
+             train_dataloader = DataLoader(
+                 train_dataset, batch_size=log_train.shape[0] + self.cnt_neg_samples, sampler=sampler
+             )
+             return train_dataset, train_dataloader
+         else:
+             dataset = UserDatasetWithReset(
+                 idx=idx,
+                 log_train=log_train,
+                 user_features=transform_user_features,
+                 item_features=transform_item_features,
+                 list_items=list_items,
+                 union_cols=self.union_cols,
+                 cnt_neg_samples=None,
+                 device=self.device,
+                 target=None,
+             )
+             dataloader = DataLoader(dataset, batch_size=log_train.shape[0], shuffle=False)
+             return dataloader
+
+     def _fit(
+         self,
+         log: DataFrame,
+         user_features: Optional[DataFrame] = None,
+         item_features: Optional[DataFrame] = None,
+     ) -> None:
+         if user_features is None:
+             msg = "User features are missing for fitting"
+             raise ValueError(msg)
+         if item_features is None:
+             msg = "Item features are missing for fitting"
+             raise ValueError(msg)
+
+         train_spl = TimeSplitter(
+             time_threshold=0.2,
+             drop_cold_items=True,
+             drop_cold_users=True,
+             query_column="user_idx",
+             item_column="item_idx",
+         )
+         train, val = train_spl.split(log)
+         train = train.drop("timestamp")
+         val = val.drop("timestamp")
+
+         train = train.toPandas()
+         val = val.toPandas()
+         pd_item_features = item_features.toPandas()
+         pd_user_features = user_features.toPandas()
+         if self.cnt_users is None:
+             self.cnt_users = pd_user_features.shape[0]
+         if self.cnt_items is None:
+             self.cnt_items = pd_item_features.shape[0]
+         self.num_of_train_labels = self.cnt_items
+
+         self.preprocess_features_fit(train, pd_item_features, pd_user_features)
+         transform_user_features, transform_item_features = self.preprocess_features_transform(
+             pd_item_features, pd_user_features
+         )
+
+         list_items = pd_item_features["item_idx"].values.tolist()
+
+         dataloader_train_users = []
+         train = train.set_axis(range(train.shape[0]), axis="index")
+         train_group_by_users = train.groupby("user_idx")
+         for idx, df_train_idx in tqdm(train_group_by_users):
+             df_train_idx = df_train_idx.loc[df_train_idx["relevance"] == 1]
+             if df_train_idx.shape[0] == 0:
+                 continue
+             df_train_idx = df_train_idx.set_axis(range(df_train_idx.shape[0]), axis="index")
+             train_dataset, train_dataloader = self._data_loader(
+                 idx, df_train_idx, transform_user_features, transform_item_features, list_items, train=True
+             )
+             dataloader_train_users.append(train_dataloader)
+
+         dataloader_val_users = []
+         self.dict_true_items_val = {}
+         transform_item_features.sort_values(by=["item_idx"], inplace=True, ignore_index=True)
+         val = val.set_axis(range(val.shape[0]), axis="index")
+         val_group_by_users = val.groupby("user_idx")
+         for idx, df_val_idx in tqdm(val_group_by_users):
+             self.dict_true_items_val[idx] = df_val_idx.loc[(df_val_idx["relevance"] == 1)]["item_idx"].values.tolist()
+             df_val = cartesian_product(pd.DataFrame({"user_idx": [idx]}), transform_item_features[["item_idx"]])
+             df_val = df_val.set_axis(range(df_val.shape[0]), axis="index")
+             dataloader_val_users.append(
+                 self._data_loader(
+                     idx, df_val, transform_user_features, transform_item_features, list_items, train=False
+                 )
+             )
+
+         self.size_wide_features, self.size_continuous_features, self.size_cat_features = (
+             train_dataset.get_size_features()
+         )
+         self.model = WideDeep(
+             dim_head=self.dim_head,
+             deep_out_dim=self.deep_out_dim,
+             hidden_layers=self.hidden_layers,
+             size_wide_features=self.size_wide_features,
+             size_continuous_features=self.size_continuous_features,
+             size_cat_features=self.size_cat_features,
+             wide_out_dim=self.wide_out_dim,
+             head_dropout=self.head_dropout,
+             deep_dropout=self.deep_dropout,
+             cnt_users=self.cnt_users,
+             cnt_items=self.cnt_items,
+             user_embed=self.embedding_sizes[0],
+             item_embed=self.embedding_sizes[1],
+             crossed_embed=self.embedding_sizes[2],
+         )
+         if self.use_warp_loss:
+             self.criterion = warp_loss
+         else:
+             self.criterion = w_log_loss
+         self.optimizer = torch.optim.AdamW(self.model.parameters(), self.opt_lr)
+         self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.n_epochs, self.lr_min)
+
+         self.train(self.model, dataloader_train_users, dataloader_val_users)
+
+     def train(self, model, train_dataloader, val_dataloader):
+         """
+         Run training loop.
+         """
+         train_losses = []
+         val_ndcg = []
+         model = model.to(self.device)
+         for epoch in range(self.n_epochs):
+             train_loss = self._batch_pass(model, train_dataloader)
+             ndcg = self.predict_val_with_ndcg(model, val_dataloader, k=10)
+             train_losses.append(train_loss)
+             val_ndcg.append(ndcg)
+
+             if self.plot_dir is not None and epoch > 0:
+                 clear_output(wait=True)
+                 _, (ax1, ax2) = plt.subplots(1, 2, figsize=(30, 15))
+                 ax1.plot(train_losses, label="train", color="b")
+                 ax2.plot(val_ndcg, label="val_ndcg", color="g")
+                 size = max(1, round(epoch / 10))
+                 plt.xticks(range(epoch - 1)[::size])
+                 ax1.set_ylabel("loss")
+                 ax1.set_xlabel("epoch")
+                 ax2.set_ylabel("ndcg")
+                 ax2.set_xlabel("epoch")
+                 plt.legend()
+                 plt.savefig(self.plot_dir)
+                 plt.show()
+             self.logger.info("ndcg val =%.4f", ndcg)
+
+     def _loss(self, preds, labels):
+         if self.use_warp_loss:
+             ind_pos = torch.where(labels == 1)[0]
+             ind_neg = torch.where(labels == 0)[0]
+             min_batch = ind_pos.shape[0]
+             if ind_pos.shape[0] == 0 or ind_neg.shape[0] == 0:
+                 return
+             indexes_pos = ind_pos
+             pos = preds.squeeze()[indexes_pos].unsqueeze(-1)
+             list_neg = []
+             for _ in range(min_batch):
+                 indexes_neg = ind_neg[torch.randperm(ind_neg.shape[0])]
+                 list_neg.append(preds.squeeze()[indexes_neg].unsqueeze(-1))
+             neg = torch.cat(list_neg, dim=-1)
+             neg = neg.transpose(0, 1)
+             loss = self.criterion(pos, neg, self.num_of_train_labels, self.device)
+         else:
+             loss = self.criterion(preds.squeeze(), labels, self.device)
+         return loss
+
+     def _batch_pass(self, model, train_dataloader):
+         """
+         Run one epoch of the training loop.
+         """
+         model.train()
+         idx = 0
+         cumulative_loss = 0
+         preds = None
+         for user_dataloader in tqdm(train_dataloader):
+             for batch in user_dataloader:
+                 wide_part, continuous_part, cat_part, users, items, labels = batch
+                 self.optimizer.zero_grad()
+                 preds = model(wide_part, continuous_part, cat_part, users, items)
+                 loss = self._loss(preds, labels)
+                 if loss is not None:
+                     loss.backward()
+                     self.optimizer.step()
+                     cumulative_loss += loss.item()
+                     idx += 1
+
+         self.lr_scheduler.step()
+         return cumulative_loss / idx
+
+     def predict_val_with_ndcg(self, model, val_dataloader, k):
+         """
+         This function returns a top-k quality metric on the validation data
+         (the share of the top-k predictions that are true items, i.e. precision@k,
+         used here as an NDCG proxy).
+         """
+         if len(val_dataloader) == 0:
+             return 0
+
+         ndcg = 0
+         idx = 0
+         model = model.to(self.device)
+         for user_dataloader in tqdm(val_dataloader):
+             _, _, _, users, _, _ = next(iter(user_dataloader))
+             user = int(users[0])
+             sample_pred = np.array(self.predict_val(model, user_dataloader))
+             top_k_predicts = (-sample_pred).argsort()[:k]
+             ndcg += (np.isin(top_k_predicts, self.dict_true_items_val[user]).sum()) / k
+             idx += 1
+
+         metric = ndcg / idx
+         return metric
+
+     def predict_val(self, model, val_dataloader):
+         """
+         This function returns the relevances for the validation data.
+         """
+         probs = []
+         model = model.to(self.device)
+         model.eval()
+         with torch.no_grad():
+             for wide_part, continuous_part, cat_part, users, items, _ in val_dataloader:
+                 preds = model(wide_part, continuous_part, cat_part, users, items)
+                 probs += (preds.squeeze()).tolist()
+         return probs
+
+     def predict_test(self, model, test_dataloader, cnt_samples_for_predict):
+         """
+         This function returns cnt_samples_for_predict relevance estimates
+         for each pair (user, item) in test_dataloader.
+         """
+         probs = []
+         model = model.to(self.device)
+         model.eval()
+         with torch.no_grad():
+             for wide_part, continuous_part, cat_part, users, items, _ in test_dataloader:
+                 preds = model.forward_for_predict(wide_part, continuous_part, cat_part, users, items)
+                 probs.extend(model.forward_dropout(preds).squeeze().tolist() for __ in range(cnt_samples_for_predict))
+         return probs
+
+     def _predict(
+         self,
+         log: DataFrame,  # noqa: ARG002
+         k: int,  # noqa: ARG002
+         users: DataFrame,
+         items: DataFrame,
+         user_features: Optional[DataFrame] = None,
+         item_features: Optional[DataFrame] = None,
+         filter_seen_items: bool = True,  # noqa: ARG002
+     ) -> DataFrame:
+         if user_features is None:
+             msg = "User features are missing for predict"
+             raise ValueError(msg)
+         if item_features is None:
+             msg = "Item features are missing for predict"
+             raise ValueError(msg)
+
+         pd_users = users.toPandas()
+         pd_items = items.toPandas()
+         pd_user_features = user_features.toPandas()
+         pd_item_features = item_features.toPandas()
+
+         list_items = pd_item_features["item_idx"].values.tolist()
+
+         transform_user_features, transform_item_features = self.preprocess_features_transform(
+             pd_item_features, pd_user_features
+         )
+
+         preds = []
+         users_ans = []
+         items_ans = []
+         for idx in tqdm(pd_users["user_idx"].unique()):
+             df_test_idx = cartesian_product(pd.DataFrame({"user_idx": [idx]}), pd_items)
+             df_test_idx = df_test_idx.set_axis(range(df_test_idx.shape[0]), axis="index")
+             test_dataloader = self._data_loader(
+                 idx, df_test_idx, transform_user_features, transform_item_features, list_items, train=False
+             )
+
+             samples = np.array(self.predict_test(self.model, test_dataloader, self.cnt_samples_for_predict))
+             sample_pred = np.mean(samples, axis=0) + self.exploration_coef * np.sqrt(np.var(samples, axis=0))
+
+             preds += sample_pred.tolist()
+             users_ans += [idx] * df_test_idx.shape[0]
+             items_ans += df_test_idx["item_idx"].values.tolist()
+
+         res_df = pd.DataFrame({"user_idx": users_ans, "item_idx": items_ans, "relevance": preds})
+         pred = convert2spark(res_df)
+         return pred
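The scoring rule above is the Thompson-sampling-style exploration step: each (user, item) pair receives cnt_samples_for_predict stochastic scores from the dropout-perturbed head, and the final relevance is the mean plus exploration_coef times the standard deviation. A toy illustration of the aggregation (numbers invented):

samples = np.array([[0.50, 0.10], [0.54, 0.30], [0.46, 0.20]])  # 3 stochastic passes, 2 pairs
relevance = np.mean(samples, axis=0) + 1.0 * np.sqrt(np.var(samples, axis=0))
# the second pair gets the larger bonus: its stochastic predictions disagree more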
+
+     @property
+     def _init_args(self):
+         return {
+             "n_epochs": self.n_epochs,
+             "union_cols": self.union_cols,
+             "cnt_users": self.cnt_users,
+             "cnt_items": self.cnt_items,
+             "size_wide_features": self.size_wide_features,
+             "size_continuous_features": self.size_continuous_features,
+             "size_cat_features": self.size_cat_features,
+         }
+
+     def model_save(self, dir_name):
+         """
+         This function saves the model.
+         """
+         os.makedirs(dir_name, exist_ok=True)
+
+         joblib.dump(self.scaler_user, os.path.join(dir_name, "scaler_user.joblib"))
+         joblib.dump(self.encoder_intersept_user, os.path.join(dir_name, "encoder_intersept_user.joblib"))
+         joblib.dump(self.encoder_diff_user, os.path.join(dir_name, "encoder_diff_user.joblib"))
+
+         joblib.dump(self.scaler_item, os.path.join(dir_name, "scaler_item.joblib"))
+         joblib.dump(self.encoder_intersept_item, os.path.join(dir_name, "encoder_intersept_item.joblib"))
+         joblib.dump(self.encoder_diff_item, os.path.join(dir_name, "encoder_diff_item.joblib"))
+
+         torch.save(self.model.state_dict(), os.path.join(dir_name, "model_weights.pth"))
+         torch.save(
+             {
+                 "fit_users": self.fit_users.toPandas(),
+                 "fit_items": self.fit_items.toPandas(),
+             },
+             os.path.join(dir_name, "fit_info.pth"),
+         )
+
+     def model_load(self, dir_name):
+         """
+         This function loads the model.
+         """
+         self.scaler_user = joblib.load(os.path.join(dir_name, "scaler_user.joblib"))
+         self.encoder_intersept_user = joblib.load(os.path.join(dir_name, "encoder_intersept_user.joblib"))
+         self.encoder_diff_user = joblib.load(os.path.join(dir_name, "encoder_diff_user.joblib"))
+
+         self.scaler_item = joblib.load(os.path.join(dir_name, "scaler_item.joblib"))
+         self.encoder_intersept_item = joblib.load(os.path.join(dir_name, "encoder_intersept_item.joblib"))
+         self.encoder_diff_item = joblib.load(os.path.join(dir_name, "encoder_diff_item.joblib"))
+
+         self.model = WideDeep(
+             dim_head=self.dim_head,
+             deep_out_dim=self.deep_out_dim,
+             hidden_layers=self.hidden_layers,
+             size_wide_features=self.size_wide_features,
+             size_continuous_features=self.size_continuous_features,
+             size_cat_features=self.size_cat_features,
+             wide_out_dim=self.wide_out_dim,
+             head_dropout=self.head_dropout,
+             deep_dropout=self.deep_dropout,
+             cnt_users=self.cnt_users,
+             cnt_items=self.cnt_items,
+             user_embed=self.embedding_sizes[0],
+             item_embed=self.embedding_sizes[1],
+             crossed_embed=self.embedding_sizes[2],
+         )
+         self.model.load_state_dict(torch.load(os.path.join(dir_name, "model_weights.pth")))
+
+         checkpoint = torch.load(os.path.join(dir_name, "fit_info.pth"), weights_only=False)
+         self.fit_users = convert2spark(checkpoint["fit_users"])
+         self.fit_items = convert2spark(checkpoint["fit_items"])
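Putting the pieces together, a hedged end-to-end sketch of driving NeuralTS (log, user_features, item_features, users and items are assumed to be Spark DataFrames with user_idx/item_idx/relevance/timestamp plus the named feature columns; fit and predict are assumed to follow the HybridRecommender base signatures, and every name here is illustrative):

model = NeuralTS(
    user_cols={"continuous_cols": ["age"], "cat_embed_cols": ["gender"], "wide_cols": ["gender"]},
    item_cols={"continuous_cols": ["price"], "cat_embed_cols": ["genre"], "wide_cols": ["genre"]},
    n_epochs=2,
    use_gpu=False,
)
model.fit(log, user_features=user_features, item_features=item_features)
recs = model.predict(
    log, k=10, users=users, items=items,
    user_features=user_features, item_features=item_features,
)
model.model_save("checkpoints/neural_ts")  # persists encoders, weights and fit info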