replay-rec 0.18.0rc0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. replay/__init__.py +1 -1
  2. replay/data/dataset.py +27 -1
  3. replay/data/dataset_utils/dataset_label_encoder.py +6 -3
  4. replay/data/nn/schema.py +37 -16
  5. replay/data/nn/sequence_tokenizer.py +313 -165
  6. replay/data/nn/torch_sequential_dataset.py +17 -8
  7. replay/data/nn/utils.py +14 -7
  8. replay/data/schema.py +10 -6
  9. replay/metrics/offline_metrics.py +2 -2
  10. replay/models/__init__.py +1 -0
  11. replay/models/base_rec.py +18 -21
  12. replay/models/lin_ucb.py +407 -0
  13. replay/models/nn/sequential/bert4rec/dataset.py +17 -4
  14. replay/models/nn/sequential/bert4rec/lightning.py +121 -54
  15. replay/models/nn/sequential/bert4rec/model.py +21 -0
  16. replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
  17. replay/models/nn/sequential/compiled/__init__.py +5 -0
  18. replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
  19. replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
  20. replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
  21. replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
  22. replay/models/nn/sequential/sasrec/dataset.py +17 -1
  23. replay/models/nn/sequential/sasrec/lightning.py +126 -50
  24. replay/models/nn/sequential/sasrec/model.py +3 -4
  25. replay/preprocessing/__init__.py +7 -1
  26. replay/preprocessing/discretizer.py +719 -0
  27. replay/preprocessing/label_encoder.py +384 -52
  28. replay/splitters/cold_user_random_splitter.py +1 -1
  29. replay/utils/__init__.py +1 -0
  30. replay/utils/common.py +7 -8
  31. replay/utils/session_handler.py +3 -4
  32. replay/utils/spark_utils.py +15 -1
  33. replay/utils/types.py +8 -0
  34. {replay_rec-0.18.0rc0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +75 -70
  35. {replay_rec-0.18.0rc0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -84
  36. {replay_rec-0.18.0rc0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +1 -1
  37. replay/experimental/__init__.py +0 -0
  38. replay/experimental/metrics/__init__.py +0 -62
  39. replay/experimental/metrics/base_metric.py +0 -602
  40. replay/experimental/metrics/coverage.py +0 -97
  41. replay/experimental/metrics/experiment.py +0 -175
  42. replay/experimental/metrics/hitrate.py +0 -26
  43. replay/experimental/metrics/map.py +0 -30
  44. replay/experimental/metrics/mrr.py +0 -18
  45. replay/experimental/metrics/ncis_precision.py +0 -31
  46. replay/experimental/metrics/ndcg.py +0 -49
  47. replay/experimental/metrics/precision.py +0 -22
  48. replay/experimental/metrics/recall.py +0 -25
  49. replay/experimental/metrics/rocauc.py +0 -49
  50. replay/experimental/metrics/surprisal.py +0 -90
  51. replay/experimental/metrics/unexpectedness.py +0 -76
  52. replay/experimental/models/__init__.py +0 -10
  53. replay/experimental/models/admm_slim.py +0 -205
  54. replay/experimental/models/base_neighbour_rec.py +0 -204
  55. replay/experimental/models/base_rec.py +0 -1271
  56. replay/experimental/models/base_torch_rec.py +0 -234
  57. replay/experimental/models/cql.py +0 -454
  58. replay/experimental/models/ddpg.py +0 -923
  59. replay/experimental/models/dt4rec/__init__.py +0 -0
  60. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  61. replay/experimental/models/dt4rec/gpt1.py +0 -401
  62. replay/experimental/models/dt4rec/trainer.py +0 -127
  63. replay/experimental/models/dt4rec/utils.py +0 -265
  64. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  65. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  66. replay/experimental/models/implicit_wrap.py +0 -131
  67. replay/experimental/models/lightfm_wrap.py +0 -302
  68. replay/experimental/models/mult_vae.py +0 -332
  69. replay/experimental/models/neuromf.py +0 -406
  70. replay/experimental/models/scala_als.py +0 -296
  71. replay/experimental/nn/data/__init__.py +0 -1
  72. replay/experimental/nn/data/schema_builder.py +0 -55
  73. replay/experimental/preprocessing/__init__.py +0 -3
  74. replay/experimental/preprocessing/data_preparator.py +0 -839
  75. replay/experimental/preprocessing/padder.py +0 -229
  76. replay/experimental/preprocessing/sequence_generator.py +0 -208
  77. replay/experimental/scenarios/__init__.py +0 -1
  78. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  79. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  80. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
  81. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  82. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  83. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  84. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  85. replay/experimental/utils/__init__.py +0 -0
  86. replay/experimental/utils/logger.py +0 -24
  87. replay/experimental/utils/model_handler.py +0 -186
  88. replay/experimental/utils/session_handler.py +0 -44
  89. replay_rec-0.18.0rc0.dist-info/NOTICE +0 -41
  90. {replay_rec-0.18.0rc0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
replay/models/lin_ucb.py
@@ -0,0 +1,407 @@
+import warnings
+from typing import List, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import scipy.sparse as scs
+from tqdm import tqdm
+
+from replay.data.dataset import Dataset
+from replay.utils import SparkDataFrame
+from replay.utils.spark_utils import convert2spark
+
+from .base_rec import HybridRecommender
+
+
+class DisjointArm:
+    """
+    Object for interactions with a single arm in a disjoint LinUCB framework.
+
+    In disjoint LinUCB, the features of all arms are disjoint (no common features are used).
+    """
+
+    def __init__(self, arm_index, d, eps, alpha):
+        # Track arm index
+        self.arm_index = arm_index
+        # Exploration parameter
+        self.eps = eps
+        self.alpha = alpha
+        # Feature matrix of the ridge regression and its inverse
+        self.A = self.alpha * np.identity(d)
+        self.A_inv = (1.0 / self.alpha) * np.identity(d)
+        # Solution of the ridge regression
+        self.theta = np.zeros(d, dtype=float)
+
+    def feature_update(self, usr_features, relevances):
+        """
+        Update the features of this LinUCB arm in the current model.
+
+        features:
+            usr_features = matrix (np.array of shape (m, d)),
+            where m = number of occurrences of the current arm in the dataset;
+            usr_features[i] = features of the i-th user who rated this particular arm (movie);
+            relevances = np.array of shape (m,) - relevances[i] is the rating of the i-th user who rated this particular arm (movie);
+        """
+        # Update A, which is a (d x d) matrix.
+        self.A += np.dot(usr_features.T, usr_features)
+        self.A_inv = np.linalg.inv(self.A)
+        # Update the parameter theta with the result of the linear regression
+        self.theta = np.linalg.lstsq(self.A, usr_features.T @ relevances, rcond=1.0)[0]
+
+
+class HybridArm:
+    """
+    Object for interactions with a single arm in a hybrid LinUCB framework.
+
+    Hybrid LinUCB combines shared and arm-specific features.
+    Preferable when there are meaningful relationships between different arms, e.g. genres, product categories, etc.
+    """
+
+    def __init__(self, arm_index, d, k, eps, alpha):
+        # Track arm index
+        self.arm_index = arm_index
+        # Exploration parameter
+        self.eps = eps
+        self.alpha = alpha
+        # Feature matrix of the ridge regression and its inverse
+        self.A = scs.csr_matrix(self.alpha * np.identity(d))
+        self.A_inv = scs.csr_matrix((1.0 / self.alpha) * np.identity(d))
+        self.B = scs.csr_matrix(np.zeros((d, k)))
+        # Right-hand side of the regression
+        self.b = np.zeros(d, dtype=float)
+
+    def feature_update(self, usr_features, usr_itm_features, relevances) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Update the features of this LinUCB arm in the current model.
+
+        features:
+            usr_features = matrix (np.array of shape (m, d)),
+            where m = number of occurrences of the current arm in the dataset;
+            usr_features[i] = features of the i-th user who rated this particular arm (movie);
+            usr_itm_features = combined user-item features (np.array of shape (m, k));
+            relevances = np.array of shape (m,) - relevances[i] is the rating of the i-th user who rated this particular arm (movie);
+        """
+        self.A += (usr_features.T).dot(usr_features)
+        self.A_inv = scs.linalg.inv(self.A)
+        self.B += (usr_features.T).dot(usr_itm_features)
+        self.b += (usr_features.T).dot(relevances)
+        delta_A_0 = np.dot(usr_itm_features.T, usr_itm_features) - self.B.T @ self.A_inv @ self.B  # noqa: N806
+        delta_b_0 = (usr_itm_features.T).dot(relevances) - (self.B.T).dot(self.A_inv.dot(self.b))
+        return delta_A_0, delta_b_0
+
+
+class LinUCB(HybridRecommender):
+    """
+    A recommender algorithm for contextual bandit problems.
+
+    Originally proposed by `Li et al <https://arxiv.org/pdf/1003.0146>`_.
+    The model assumes a linear relationship between user context, item features and action rewards,
+    making it efficient for high-dimensional contexts.
+
+    Note:
+        It's recommended to scale features to a similar range (e.g., using StandardScaler or MinMaxScaler)
+        to ensure proper convergence and prevent numerical instability (since the relationships to learn are linear).
+
+    >>> import pandas as pd
+    >>> from replay.data.dataset import (
+    ...     Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
+    ... )
+    >>> interactions = pd.DataFrame({"user_id": [0, 1, 2, 2], "item_id": [0, 1, 0, 1], "rating": [1, 0, 0, 0]})
+    >>> user_features = pd.DataFrame(
+    ...     {"user_id": [0, 1, 2], "usr_feat_1": [1, 2, 3], "usr_feat_2": [4, 5, 6], "usr_feat_3": [7, 8, 9]}
+    ... )
+    >>> item_features = pd.DataFrame(
+    ...     {
+    ...         "item_id": [0, 1, 2, 3, 4, 5],
+    ...         "itm_feat_1": [1, 2, 3, 4, 5, 6],
+    ...         "itm_feat_2": [7, 8, 9, 10, 11, 12],
+    ...         "itm_feat_3": [13, 14, 15, 16, 17, 18]
+    ...     }
+    ... )
+    >>> feature_schema = FeatureSchema(
+    ...     [
+    ...         FeatureInfo(
+    ...             column="user_id",
+    ...             feature_type=FeatureType.CATEGORICAL,
+    ...             feature_hint=FeatureHint.QUERY_ID,
+    ...         ),
+    ...         FeatureInfo(
+    ...             column="item_id",
+    ...             feature_type=FeatureType.CATEGORICAL,
+    ...             feature_hint=FeatureHint.ITEM_ID,
+    ...         ),
+    ...         FeatureInfo(
+    ...             column="rating",
+    ...             feature_type=FeatureType.NUMERICAL,
+    ...             feature_hint=FeatureHint.RATING,
+    ...         ),
+    ...         *[
+    ...             FeatureInfo(
+    ...                 column=name, feature_type=FeatureType.NUMERICAL, feature_source=FeatureSource.ITEM_FEATURES,
+    ...             )
+    ...             for name in ["itm_feat_1", "itm_feat_2", "itm_feat_3"]
+    ...         ],
+    ...         *[
+    ...             FeatureInfo(
+    ...                 column=name, feature_type=FeatureType.NUMERICAL, feature_source=FeatureSource.QUERY_FEATURES
+    ...             )
+    ...             for name in ["usr_feat_1", "usr_feat_2", "usr_feat_3"]
+    ...         ],
+    ...     ]
+    ... )
+    >>> dataset = Dataset(
+    ...     feature_schema=feature_schema,
+    ...     interactions=interactions,
+    ...     item_features=item_features,
+    ...     query_features=user_features,
+    ...     categorical_encoded=True,
+    ... )
+    >>> dataset.to_spark()
+    >>> model = LinUCB(eps=-10.0, alpha=1.0, is_hybrid=False)
+    >>> model.fit(dataset)
+    >>> model.predict(dataset, k=2, queries=[0, 1, 2]).toPandas().sort_values(["user_id", "rating", "item_id"],
+    ...     ascending=[True, False, True]).reset_index(drop=True)
+       user_id  item_id      rating
+    0        0        1  -11.073741
+    1        0        2  -81.240384
+    2        1        0   -6.555529
+    3        1        2  -96.436508
+    4        2        2 -112.249722
+    5        2        3 -112.249722
+
+    """
+
+    _search_space = {
+        "eps": {"type": "uniform", "args": [-10.0, 10.0]},
+        "alpha": {"type": "uniform", "args": [0.001, 10.0]},
+    }
+    _study = None  # field required for proper optuna optimization
+    linucb_arms: List[Union[DisjointArm, HybridArm]]  # initialized only inside the fit method
+    rel_matrix: np.ndarray  # matrix with relevance scores from the predict method
+
+    def __init__(
+        self,
+        eps: float,
+        alpha: float = 1.0,
+        is_hybrid: bool = False,
+    ):
+        """
+        :param eps: exploration coefficient
+        :param alpha: ridge parameter
+        :param is_hybrid: flag to choose the model type. If True, the model is hybrid.
+        """
+        self.is_hybrid = is_hybrid
+        self.eps = eps
+        self.alpha = alpha
+
+    @property
+    def _init_args(self):
+        return {"is_hybrid": self.is_hybrid}
+
+    def _verify_features(self, dataset: Dataset):
+        if dataset.query_features is None:
+            msg = "User features are missing"
+            raise ValueError(msg)
+        if dataset.item_features is None:
+            msg = "Item features are missing"
+            raise ValueError(msg)
+        if (
+            len(dataset.feature_schema.query_features.categorical_features) > 0
+            or len(dataset.feature_schema.item_features.categorical_features) > 0
+        ):
+            msg = "Categorical features are not supported"
+            raise ValueError(msg)
+
+    def _fit(
+        self,
+        dataset: Dataset,
+    ) -> None:
+        self._verify_features(dataset)
+
+        if not dataset.is_pandas:
+            warn_msg = "Dataset will be converted to pandas during internal calculations in fit"
+            warnings.warn(warn_msg)
+            dataset.to_pandas()
+
+        feature_schema = dataset.feature_schema
+        log = dataset.interactions
+        user_features = dataset.query_features
+        item_features = dataset.item_features
+
+        self._num_items = item_features.shape[0]
+        self._user_dim_size = user_features.shape[1] - 1
+        self._item_dim_size = item_features.shape[1] - 1
+
+        # Now initialize an arm object for each potential arm instance
+        if self.is_hybrid:
+            hybrid_features_k = self._user_dim_size * self._item_dim_size
+            self.A_0 = scs.csr_matrix(np.identity(hybrid_features_k))
+            self.b_0 = np.zeros(hybrid_features_k, dtype=float)
+            self.linucb_arms = [
+                HybridArm(
+                    arm_index=i,
+                    d=self._user_dim_size,
+                    k=hybrid_features_k,
+                    eps=self.eps,
+                    alpha=self.alpha,
+                )
+                for i in range(self._num_items)
+            ]
+
+            for i in tqdm(range(self._num_items)):
+                B = log.loc[log[feature_schema.item_id_column] == i]  # noqa: N806
+                idxs_list = B[feature_schema.query_id_column].values
+                rel_list = B[feature_schema.interactions_rating_column].values
+                if not B.empty:
+                    # If at least one user has interacted with arm i
+                    cur_usrs = scs.csr_matrix(
+                        user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
+                        .drop(columns=[feature_schema.query_id_column])
+                        .to_numpy()
+                    )
+                    cur_itm = scs.csr_matrix(
+                        item_features.iloc[i].drop(labels=[feature_schema.item_id_column]).to_numpy()
+                    )
+                    usr_itm_features = scs.kron(cur_usrs, cur_itm)
+                    delta_A_0, delta_b_0 = self.linucb_arms[i].feature_update(  # noqa: N806
+                        cur_usrs, usr_itm_features, rel_list
+                    )
+
+                    self.A_0 += delta_A_0
+                    self.b_0 += delta_b_0
+
+            self.beta = scs.linalg.spsolve(self.A_0, self.b_0)
+            self.A_0_inv = scs.linalg.inv(self.A_0)
+
+            for i in range(self._num_items):
+                self.linucb_arms[i].theta = scs.linalg.spsolve(
+                    self.linucb_arms[i].A,
+                    self.linucb_arms[i].b - self.linucb_arms[i].B @ self.beta,
+                )
+        else:
+            self.linucb_arms = [
+                DisjointArm(arm_index=i, d=self._user_dim_size, eps=self.eps, alpha=self.alpha)
+                for i in range(self._num_items)
+            ]
+
+            for i in range(self._num_items):
+                B = log.loc[log[feature_schema.item_id_column] == i]  # noqa: N806
+                idxs_list = B[feature_schema.query_id_column].values  # noqa: F841
+                rel_list = B[feature_schema.interactions_rating_column].values
+                if not B.empty:
+                    # If at least one user has interacted with arm i
+                    cur_usrs = user_features.query(f"{feature_schema.query_id_column} in @idxs_list").drop(
+                        columns=[feature_schema.query_id_column]
+                    )
+                    self.linucb_arms[i].feature_update(cur_usrs.to_numpy(), rel_list)
+
+        warn_msg = "Dataset will be converted to spark after internal calculations in fit"
+        warnings.warn(warn_msg)
+        dataset.to_spark()
+
+    def _predict(
+        self,
+        dataset: Dataset,
+        k: int,
+        users: SparkDataFrame,
+        items: SparkDataFrame = None,
+        filter_seen_items: bool = True,  # noqa: ARG002
+        oversample: int = 20,
+    ) -> SparkDataFrame:
+        self._verify_features(dataset)
+
+        if not dataset.is_pandas:
+            warn_msg = "Dataset will be converted to pandas during internal calculations in predict"
+            warnings.warn(warn_msg)
+            dataset.to_pandas()
+
+        feature_schema = dataset.feature_schema
+        user_features = dataset.query_features
+        item_features = dataset.item_features
+        big_k = min(oversample * k, item_features.shape[0])
+
+        users = users.toPandas()
+        num_user_pred = users.shape[0]
+        rel_matrix = np.zeros((num_user_pred, self._num_items), dtype=float)
+
+        if self.is_hybrid:
+            items = items.toPandas()
+            usr_idxs_list = users[feature_schema.query_id_column].values
+            itm_idxs_list = items[feature_schema.item_id_column].values  # noqa: F841
+
+            usrs_feat = scs.csr_matrix(
+                user_features.query(f"{feature_schema.query_id_column} in @usr_idxs_list")
+                .drop(columns=[feature_schema.query_id_column])
+                .to_numpy()
+            )
+            itm_feat = scs.csr_matrix(
+                item_features.query(f"{feature_schema.item_id_column} in @itm_idxs_list")
+                .drop(columns=[feature_schema.item_id_column])
+                .to_numpy()
+            )
+
+            # fill in relevance matrix
+            for i in tqdm(range(self._num_items)):
+                z = scs.kron(usrs_feat, itm_feat[i])
+                rel_matrix[:, i] = usrs_feat.dot(self.linucb_arms[i].theta)
+                rel_matrix[:, i] += z.dot(self.beta)
+
+                s = (usrs_feat.dot(self.linucb_arms[i].A_inv).multiply(usrs_feat)).sum(axis=1)
+                s += (z.dot(self.A_0_inv).multiply(z)).sum(axis=1)
+                M = self.A_0_inv @ self.linucb_arms[i].B.T @ self.linucb_arms[i].A_inv  # noqa: N806
+                s -= 2 * (z.dot(M).multiply(usrs_feat)).sum(axis=1)
+                s += (usrs_feat.dot(M.T @ self.linucb_arms[i].B.T @ self.linucb_arms[i].A_inv).multiply(usrs_feat)).sum(
+                    axis=1
+                )
+
+                rel_matrix[:, i] += np.array(self.eps * np.sqrt(s))[:, 0]
+
+            # select top k predictions from each row (unsorted ones)
+            topk_indices = np.argpartition(rel_matrix, -big_k, axis=1)[:, -big_k:]
+            rows_inds, _ = np.indices((num_user_pred, big_k))
+            # result df
+            predict_inds = np.repeat(usr_idxs_list, big_k)
+            predict_items = topk_indices.ravel()
+            predict_rels = rel_matrix[rows_inds, topk_indices].ravel()
+            # return everything in a PySpark template
+            res_df = pd.DataFrame(
+                {
+                    feature_schema.query_id_column: predict_inds,
+                    feature_schema.item_id_column: predict_items,
+                    feature_schema.interactions_rating_column: predict_rels,
+                }
+            )
+
+        else:
+            idxs_list = users[feature_schema.query_id_column].values
+            usrs_feat = (
+                user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
+                .drop(columns=[feature_schema.query_id_column])
+                .to_numpy()
+            )
+            # fill in relevance matrix
+            for i in range(self._num_items):
+                rel_matrix[:, i] = (
+                    self.eps * np.sqrt((usrs_feat.dot(self.linucb_arms[i].A_inv) * usrs_feat).sum(axis=1))
+                    + usrs_feat @ self.linucb_arms[i].theta
+                )
+            # select top k predictions from each row (unsorted ones)
+            topk_indices = np.argpartition(rel_matrix, -big_k, axis=1)[:, -big_k:]
+            rows_inds, _ = np.indices((num_user_pred, big_k))
+            # result df
+            predict_inds = np.repeat(idxs_list, big_k)
+            predict_items = topk_indices.ravel()
+            predict_rels = rel_matrix[rows_inds, topk_indices].ravel()
+            # return everything in a PySpark template
+            res_df = pd.DataFrame(
+                {
+                    feature_schema.query_id_column: predict_inds,
+                    feature_schema.item_id_column: predict_items,
+                    feature_schema.interactions_rating_column: predict_rels,
+                }
+            )
+
+        warn_msg = "Dataset will be converted to spark after internal calculations in predict"
+        warnings.warn(warn_msg)
+        dataset.to_spark()
+        return convert2spark(res_df)
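
For orientation: in the disjoint branch above, `_predict` scores every arm with the standard LinUCB rule, an exploitation term x·theta plus an eps-scaled exploration bonus sqrt(xᵀ A⁻¹ x). A minimal NumPy sketch of that computation on made-up toy data (all names and values illustrative, not part of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    usrs_feat = rng.normal(size=(5, 3))  # 5 users, d=3 context features

    # Per-arm state, as DisjointArm would hold it after fit (alpha=1 ridge):
    A = np.identity(3) + usrs_feat.T @ usrs_feat
    A_inv = np.linalg.inv(A)
    theta = np.linalg.solve(A, usrs_feat.T @ rng.normal(size=5))

    # UCB score per user for this arm, mirroring the disjoint branch of `_predict`;
    # with a negative eps (as in the doctest) the bonus becomes a penalty.
    eps = -10.0
    scores = eps * np.sqrt((usrs_feat @ A_inv * usrs_feat).sum(axis=1)) + usrs_feat @ theta
    print(scores.shape)  # (5,)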
replay/models/nn/sequential/bert4rec/dataset.py
@@ -12,6 +12,7 @@ from replay.data.nn import (
     TorchSequentialDataset,
     TorchSequentialValidationDataset,
 )
+from replay.utils.model_handler import deprecation_warning


 class Bert4RecTrainingBatch(NamedTuple):
@@ -88,6 +89,10 @@ class Bert4RecTrainingDataset(TorchDataset):
     Dataset that generates samples to train BERT-like model
     """

+    @deprecation_warning(
+        "`padding_value` parameter will be removed in future versions. "
+        "Instead, you should specify `padding_value` for each column in TensorSchema"
+    )
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -176,6 +181,10 @@ class Bert4RecPredictionDataset(TorchDataset):
     Dataset that generates samples to infer BERT-like model
     """

+    @deprecation_warning(
+        "`padding_value` parameter will be removed in future versions. "
+        "Instead, you should specify `padding_value` for each column in TensorSchema"
+    )
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -230,6 +239,10 @@ class Bert4RecValidationDataset(TorchDataset):
    Dataset that generates samples to infer and validate BERT-like model
    """

+    @deprecation_warning(
+        "`padding_value` parameter will be removed in future versions. "
+        "Instead, you should specify `padding_value` for each column in TensorSchema"
+    )
     def __init__(
         self,
         sequential: SequentialDataset,
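
The three decorators above all point to the same migration: padding becomes a per-column property of the schema instead of a dataset-level argument. A hedged sketch of what that presumably looks like; the `padding_value` keyword on `TensorFeatureInfo` is inferred from the deprecation message and from the `feature.padding_value` access in `_shift_features` below, not from documented API:

    from replay.data import FeatureType
    from replay.data.nn import TensorFeatureInfo, TensorSchema

    # Assumption: TensorFeatureInfo now accepts padding_value per column,
    # replacing the deprecated dataset-level `padding_value` parameter.
    tensor_schema = TensorSchema(
        TensorFeatureInfo(
            name="item_id_seq",
            feature_type=FeatureType.CATEGORICAL,
            is_seq=True,
            cardinality=1000,
            padding_value=0,
        )
    )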
@@ -286,12 +299,12 @@ def _shift_features(
     shifted_features: MutableTensorMap = {}
     for feature_name, feature in schema.items():
         if feature.is_seq:
-            shifted_features[feature_name] = _shift_seq(features[feature_name])
+            shifted_features[feature_name] = _shift_seq(features[feature_name], feature.padding_value)
         else:
             shifted_features[feature_name] = features[feature_name]

     # [0, 0, 1, 1, 1] -> [0, 1, 1, 1, 0]
-    tokens_mask = _shift_seq(padding_mask)
+    tokens_mask = _shift_seq(padding_mask, 0)

     # [0, 1, 1, 1, 0] -> [0, 1, 1, 1, 1]
     shifted_padding_mask = tokens_mask.clone()
@@ -304,7 +317,7 @@ def _shift_features(
     )


-def _shift_seq(seq: torch.Tensor) -> torch.Tensor:
+def _shift_seq(seq: torch.Tensor, padding_value: int) -> torch.Tensor:
     shifted_seq = seq.roll(-1, dims=0)
-    shifted_seq[-1, ...] = 0
+    shifted_seq[-1, ...] = padding_value
     return shifted_seq
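
The `_shift_seq` change is the consumer of those per-column padding values: the helper rolls a sequence left by one step and fills the freed last position, which was previously hardcoded to zero. A tiny torch illustration of the behavior shown above (the padding value -1 is arbitrary):

    import torch

    seq = torch.tensor([10, 11, 12, 13])
    shifted = seq.roll(-1, dims=0)  # [11, 12, 13, 10]
    shifted[-1] = -1                # padding_value; before this release the fill was always 0
    print(shifted)                  # tensor([11, 12, 13, -1])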