replay-rec 0.20.0rc0__py3-none-any.whl → 0.20.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/nn/sequence_tokenizer.py +10 -3
- replay/data/nn/sequential_dataset.py +18 -14
- replay/data/nn/torch_sequential_dataset.py +12 -12
- replay/models/lin_ucb.py +55 -9
- replay/models/nn/sequential/bert4rec/dataset.py +3 -16
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +3 -16
- replay/utils/__init__.py +0 -1
- {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1rc0.dist-info}/METADATA +1 -1
- {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1rc0.dist-info}/RECORD +14 -15
- replay/utils/warnings.py +0 -26
- {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1rc0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1rc0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1rc0.dist-info}/licenses/NOTICE +0 -0
replay/__init__.py
CHANGED
|
@@ -15,7 +15,6 @@ from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, Feat
|
|
|
15
15
|
from replay.data.dataset_utils import DatasetLabelEncoder
|
|
16
16
|
from replay.preprocessing import LabelEncoder, LabelEncodingRule
|
|
17
17
|
from replay.preprocessing.label_encoder import HandleUnknownStrategies
|
|
18
|
-
from replay.utils import deprecation_warning
|
|
19
18
|
|
|
20
19
|
if TYPE_CHECKING:
|
|
21
20
|
from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
|
|
@@ -406,7 +405,6 @@ class SequenceTokenizer:
|
|
|
406
405
|
tensor_feature._set_cardinality(dataset_feature.cardinality)
|
|
407
406
|
|
|
408
407
|
@classmethod
|
|
409
|
-
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
410
408
|
def load(cls, path: str, use_pickle: bool = False, **kwargs) -> "SequenceTokenizer":
|
|
411
409
|
"""
|
|
412
410
|
Load tokenizer object from the given path.
|
|
@@ -450,12 +448,16 @@ class SequenceTokenizer:
|
|
|
450
448
|
tokenizer._encoder._features_columns = encoder_features_columns
|
|
451
449
|
tokenizer._encoder._encoding_rules = tokenizer_dict["encoder"]["encoding_rules"]
|
|
452
450
|
else:
|
|
451
|
+
warnings.warn(
|
|
452
|
+
"with `use_pickle` equals to `True` will be deprecated in future versions",
|
|
453
|
+
DeprecationWarning,
|
|
454
|
+
stacklevel=2,
|
|
455
|
+
)
|
|
453
456
|
with open(path, "rb") as file:
|
|
454
457
|
tokenizer = pickle.load(file)
|
|
455
458
|
|
|
456
459
|
return tokenizer
|
|
457
460
|
|
|
458
|
-
@deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
|
|
459
461
|
def save(self, path: str, use_pickle: bool = False) -> None:
|
|
460
462
|
"""
|
|
461
463
|
Save the tokenizer to the given path.
|
|
@@ -496,6 +498,11 @@ class SequenceTokenizer:
|
|
|
496
498
|
with open(base_path / "init_args.json", "w+") as file:
|
|
497
499
|
json.dump(tokenizer_dict, file)
|
|
498
500
|
else:
|
|
501
|
+
warnings.warn(
|
|
502
|
+
"with `use_pickle` equals to `True` will be deprecated in future versions",
|
|
503
|
+
DeprecationWarning,
|
|
504
|
+
stacklevel=2,
|
|
505
|
+
)
|
|
499
506
|
with open(path, "wb") as file:
|
|
500
507
|
pickle.dump(self, file)
|
|
501
508
|
|
|
@@ -110,17 +110,27 @@ class SequentialDataset(abc.ABC):
|
|
|
110
110
|
|
|
111
111
|
sequential_dict = {}
|
|
112
112
|
sequential_dict["_class_name"] = self.__class__.__name__
|
|
113
|
-
|
|
113
|
+
|
|
114
|
+
df = SequentialDataset._convert_array_to_list(self._sequences)
|
|
115
|
+
df.reset_index().to_parquet(base_path / "sequences.parquet")
|
|
114
116
|
sequential_dict["init_args"] = {
|
|
115
117
|
"tensor_schema": self._tensor_schema._get_object_args(),
|
|
116
118
|
"query_id_column": self._query_id_column,
|
|
117
119
|
"item_id_column": self._item_id_column,
|
|
118
|
-
"sequences_path": "sequences.
|
|
120
|
+
"sequences_path": "sequences.parquet",
|
|
119
121
|
}
|
|
120
122
|
|
|
121
123
|
with open(base_path / "init_args.json", "w+") as file:
|
|
122
124
|
json.dump(sequential_dict, file)
|
|
123
125
|
|
|
126
|
+
@staticmethod
|
|
127
|
+
def _convert_array_to_list(df):
|
|
128
|
+
return df.map(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)
|
|
129
|
+
|
|
130
|
+
@staticmethod
|
|
131
|
+
def _convert_list_to_array(df):
|
|
132
|
+
return df.map(lambda x: np.array(x) if isinstance(x, list) else x)
|
|
133
|
+
|
|
124
134
|
|
|
125
135
|
class PandasSequentialDataset(SequentialDataset):
|
|
126
136
|
"""
|
|
@@ -149,7 +159,7 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
149
159
|
if sequences.index.name != query_id_column:
|
|
150
160
|
sequences = sequences.set_index(query_id_column)
|
|
151
161
|
|
|
152
|
-
self._sequences = sequences
|
|
162
|
+
self._sequences = SequentialDataset._convert_list_to_array(sequences)
|
|
153
163
|
|
|
154
164
|
def __len__(self) -> int:
|
|
155
165
|
return len(self._sequences)
|
|
@@ -206,7 +216,8 @@ class PandasSequentialDataset(SequentialDataset):
|
|
|
206
216
|
with open(base_path / "init_args.json") as file:
|
|
207
217
|
sequential_dict = json.loads(file.read())
|
|
208
218
|
|
|
209
|
-
sequences = pd.
|
|
219
|
+
sequences = pd.read_parquet(base_path / sequential_dict["init_args"]["sequences_path"])
|
|
220
|
+
sequences = cls._convert_array_to_list(sequences)
|
|
210
221
|
dataset = cls(
|
|
211
222
|
tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
|
|
212
223
|
query_id_column=sequential_dict["init_args"]["query_id_column"],
|
|
@@ -258,18 +269,11 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
258
269
|
|
|
259
270
|
def _convert_polars_to_pandas(self, df: PolarsDataFrame) -> PandasDataFrame:
|
|
260
271
|
pandas_df = PandasDataFrame(df.to_dict(as_series=False))
|
|
261
|
-
|
|
262
|
-
for column in pandas_df.select_dtypes(include="object").columns:
|
|
263
|
-
if isinstance(pandas_df[column].iloc[0], list):
|
|
264
|
-
pandas_df[column] = pandas_df[column].apply(lambda x: np.array(x))
|
|
265
|
-
|
|
272
|
+
pandas_df = SequentialDataset._convert_list_to_array(pandas_df)
|
|
266
273
|
return pandas_df
|
|
267
274
|
|
|
268
275
|
def _convert_pandas_to_polars(self, df: PandasDataFrame) -> PolarsDataFrame:
|
|
269
|
-
|
|
270
|
-
if isinstance(df[column].iloc[0], np.ndarray):
|
|
271
|
-
df[column] = df[column].apply(lambda x: x.tolist())
|
|
272
|
-
|
|
276
|
+
df = SequentialDataset._convert_array_to_list(df)
|
|
273
277
|
return pl.from_dict(df.to_dict("list"))
|
|
274
278
|
|
|
275
279
|
@classmethod
|
|
@@ -290,7 +294,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):
|
|
|
290
294
|
with open(base_path / "init_args.json") as file:
|
|
291
295
|
sequential_dict = json.loads(file.read())
|
|
292
296
|
|
|
293
|
-
sequences = pl.
|
|
297
|
+
sequences = pl.from_pandas(pd.read_parquet(base_path / sequential_dict["init_args"]["sequences_path"]))
|
|
294
298
|
dataset = cls(
|
|
295
299
|
tensor_schema=TensorSchema._create_object_by_args(sequential_dict["init_args"]["tensor_schema"]),
|
|
296
300
|
query_id_column=sequential_dict["init_args"]["query_id_column"],
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
from collections.abc import Generator, Sequence
|
|
2
3
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
|
|
3
4
|
|
|
@@ -5,8 +6,6 @@ import numpy as np
|
|
|
5
6
|
import torch
|
|
6
7
|
from torch.utils.data import Dataset as TorchDataset
|
|
7
8
|
|
|
8
|
-
from replay.utils import deprecation_warning
|
|
9
|
-
|
|
10
9
|
if TYPE_CHECKING:
|
|
11
10
|
from .schema import TensorFeatureInfo, TensorMap, TensorSchema
|
|
12
11
|
from .sequential_dataset import SequentialDataset
|
|
@@ -29,16 +28,12 @@ class TorchSequentialDataset(TorchDataset):
|
|
|
29
28
|
Torch dataset for sequential recommender models
|
|
30
29
|
"""
|
|
31
30
|
|
|
32
|
-
@deprecation_warning(
|
|
33
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
34
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
35
|
-
)
|
|
36
31
|
def __init__(
|
|
37
32
|
self,
|
|
38
33
|
sequential: "SequentialDataset",
|
|
39
34
|
max_sequence_length: int,
|
|
40
35
|
sliding_window_step: Optional[int] = None,
|
|
41
|
-
padding_value: int =
|
|
36
|
+
padding_value: Optional[int] = None,
|
|
42
37
|
) -> None:
|
|
43
38
|
"""
|
|
44
39
|
:param sequential: sequential dataset
|
|
@@ -53,6 +48,15 @@ class TorchSequentialDataset(TorchDataset):
|
|
|
53
48
|
self._sequential = sequential
|
|
54
49
|
self._max_sequence_length = max_sequence_length
|
|
55
50
|
self._sliding_window_step = sliding_window_step
|
|
51
|
+
if padding_value is not None:
|
|
52
|
+
warnings.warn(
|
|
53
|
+
"`padding_value` parameter will be removed in future versions. "
|
|
54
|
+
"Instead, you should specify `padding_value` for each column in TensorSchema",
|
|
55
|
+
DeprecationWarning,
|
|
56
|
+
stacklevel=2,
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
padding_value = 0
|
|
56
60
|
self._padding_value = padding_value
|
|
57
61
|
self._index2sequence_map = self._build_index2sequence_map()
|
|
58
62
|
|
|
@@ -177,17 +181,13 @@ class TorchSequentialValidationDataset(TorchDataset):
|
|
|
177
181
|
Torch dataset for sequential recommender models that additionally stores ground truth
|
|
178
182
|
"""
|
|
179
183
|
|
|
180
|
-
@deprecation_warning(
|
|
181
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
182
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
183
|
-
)
|
|
184
184
|
def __init__(
|
|
185
185
|
self,
|
|
186
186
|
sequential: "SequentialDataset",
|
|
187
187
|
ground_truth: "SequentialDataset",
|
|
188
188
|
train: "SequentialDataset",
|
|
189
189
|
max_sequence_length: int,
|
|
190
|
-
padding_value: int =
|
|
190
|
+
padding_value: Optional[int] = None,
|
|
191
191
|
sliding_window_step: Optional[int] = None,
|
|
192
192
|
label_feature_name: Optional[str] = None,
|
|
193
193
|
):
|
replay/models/lin_ucb.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from
|
|
2
|
+
from os.path import join
|
|
3
|
+
from typing import Optional, Union
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
@@ -8,7 +9,11 @@ from tqdm import tqdm
|
|
|
8
9
|
|
|
9
10
|
from replay.data.dataset import Dataset
|
|
10
11
|
from replay.utils import SparkDataFrame
|
|
11
|
-
from replay.utils.spark_utils import
|
|
12
|
+
from replay.utils.spark_utils import (
|
|
13
|
+
convert2spark,
|
|
14
|
+
load_pickled_from_parquet,
|
|
15
|
+
save_picklable_to_parquet,
|
|
16
|
+
)
|
|
12
17
|
|
|
13
18
|
from .base_rec import HybridRecommender
|
|
14
19
|
|
|
@@ -177,6 +182,7 @@ class LinUCB(HybridRecommender):
|
|
|
177
182
|
_study = None # field required for proper optuna's optimization
|
|
178
183
|
linucb_arms: list[Union[DisjointArm, HybridArm]] # initialize only when working within fit method
|
|
179
184
|
rel_matrix: np.array # matrix with relevance scores from predict method
|
|
185
|
+
_num_items: int # number of items/arms
|
|
180
186
|
|
|
181
187
|
def __init__(
|
|
182
188
|
self,
|
|
@@ -195,7 +201,7 @@ class LinUCB(HybridRecommender):
|
|
|
195
201
|
|
|
196
202
|
@property
|
|
197
203
|
def _init_args(self):
|
|
198
|
-
return {"is_hybrid": self.is_hybrid}
|
|
204
|
+
return {"is_hybrid": self.is_hybrid, "eps": self.eps, "alpha": self.alpha}
|
|
199
205
|
|
|
200
206
|
def _verify_features(self, dataset: Dataset):
|
|
201
207
|
if dataset.query_features is None:
|
|
@@ -230,6 +236,7 @@ class LinUCB(HybridRecommender):
|
|
|
230
236
|
self._num_items = item_features.shape[0]
|
|
231
237
|
self._user_dim_size = user_features.shape[1] - 1
|
|
232
238
|
self._item_dim_size = item_features.shape[1] - 1
|
|
239
|
+
self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
|
|
233
240
|
|
|
234
241
|
# now initialize an arm object for each potential arm instance
|
|
235
242
|
if self.is_hybrid:
|
|
@@ -248,11 +255,14 @@ class LinUCB(HybridRecommender):
|
|
|
248
255
|
]
|
|
249
256
|
|
|
250
257
|
for i in tqdm(range(self._num_items)):
|
|
251
|
-
B = log.loc[
|
|
252
|
-
|
|
253
|
-
|
|
258
|
+
B = log.loc[ # noqa: N806
|
|
259
|
+
(log[feature_schema.item_id_column] == i)
|
|
260
|
+
& (log[feature_schema.query_id_column].isin(self._user_idxs_list))
|
|
261
|
+
]
|
|
254
262
|
if not B.empty:
|
|
255
263
|
# if we have at least one user interacting with the hand i
|
|
264
|
+
idxs_list = B[feature_schema.query_id_column].values
|
|
265
|
+
rel_list = B[feature_schema.interactions_rating_column].values
|
|
256
266
|
cur_usrs = scs.csr_matrix(
|
|
257
267
|
user_features.query(f"{feature_schema.query_id_column} in @idxs_list")
|
|
258
268
|
.drop(columns=[feature_schema.query_id_column])
|
|
@@ -284,11 +294,14 @@ class LinUCB(HybridRecommender):
|
|
|
284
294
|
]
|
|
285
295
|
|
|
286
296
|
for i in range(self._num_items):
|
|
287
|
-
B = log.loc[
|
|
288
|
-
|
|
289
|
-
|
|
297
|
+
B = log.loc[ # noqa: N806
|
|
298
|
+
(log[feature_schema.item_id_column] == i)
|
|
299
|
+
& (log[feature_schema.query_id_column].isin(self._user_idxs_list))
|
|
300
|
+
]
|
|
290
301
|
if not B.empty:
|
|
291
302
|
# if we have at least one user interacting with the hand i
|
|
303
|
+
idxs_list = B[feature_schema.query_id_column].values # noqa: F841
|
|
304
|
+
rel_list = B[feature_schema.interactions_rating_column].values
|
|
292
305
|
cur_usrs = user_features.query(f"{feature_schema.query_id_column} in @idxs_list").drop(
|
|
293
306
|
columns=[feature_schema.query_id_column]
|
|
294
307
|
)
|
|
@@ -318,8 +331,10 @@ class LinUCB(HybridRecommender):
|
|
|
318
331
|
user_features = dataset.query_features
|
|
319
332
|
item_features = dataset.item_features
|
|
320
333
|
big_k = min(oversample * k, item_features.shape[0])
|
|
334
|
+
self._user_idxs_list = set(user_features[feature_schema.query_id_column].values)
|
|
321
335
|
|
|
322
336
|
users = users.toPandas()
|
|
337
|
+
users = users[users[feature_schema.query_id_column].isin(self._user_idxs_list)]
|
|
323
338
|
num_user_pred = users.shape[0]
|
|
324
339
|
rel_matrix = np.zeros((num_user_pred, self._num_items), dtype=float)
|
|
325
340
|
|
|
@@ -404,3 +419,34 @@ class LinUCB(HybridRecommender):
|
|
|
404
419
|
warnings.warn(warn_msg)
|
|
405
420
|
dataset.to_spark()
|
|
406
421
|
return convert2spark(res_df)
|
|
422
|
+
|
|
423
|
+
def _save_model(self, path: str, additional_params: Optional[dict] = None):
|
|
424
|
+
super()._save_model(path, additional_params)
|
|
425
|
+
|
|
426
|
+
save_picklable_to_parquet(self.linucb_arms, join(path, "linucb_arms.dump"))
|
|
427
|
+
|
|
428
|
+
if self.is_hybrid:
|
|
429
|
+
linucb_hybrid_shared_params = {
|
|
430
|
+
"A_0": self.A_0,
|
|
431
|
+
"A_0_inv": self.A_0_inv,
|
|
432
|
+
"b_0": self.b_0,
|
|
433
|
+
"beta": self.beta,
|
|
434
|
+
}
|
|
435
|
+
save_picklable_to_parquet(
|
|
436
|
+
linucb_hybrid_shared_params,
|
|
437
|
+
join(path, "linucb_hybrid_shared_params.dump"),
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
def _load_model(self, path: str):
|
|
441
|
+
super()._load_model(path)
|
|
442
|
+
|
|
443
|
+
loaded_linucb_arms = load_pickled_from_parquet(join(path, "linucb_arms.dump"))
|
|
444
|
+
self.linucb_arms = loaded_linucb_arms
|
|
445
|
+
self._num_items = len(loaded_linucb_arms)
|
|
446
|
+
|
|
447
|
+
if self.is_hybrid:
|
|
448
|
+
loaded_linucb_hybrid_shared_params = load_pickled_from_parquet(
|
|
449
|
+
join(path, "linucb_hybrid_shared_params.dump")
|
|
450
|
+
)
|
|
451
|
+
for param, value in loaded_linucb_hybrid_shared_params.items():
|
|
452
|
+
setattr(self, param, value)
|
|
@@ -12,7 +12,6 @@ from replay.data.nn import (
|
|
|
12
12
|
TorchSequentialDataset,
|
|
13
13
|
TorchSequentialValidationDataset,
|
|
14
14
|
)
|
|
15
|
-
from replay.utils import deprecation_warning
|
|
16
15
|
|
|
17
16
|
|
|
18
17
|
class Bert4RecTrainingBatch(NamedTuple):
|
|
@@ -89,10 +88,6 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
89
88
|
Dataset that generates samples to train BERT-like model
|
|
90
89
|
"""
|
|
91
90
|
|
|
92
|
-
@deprecation_warning(
|
|
93
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
94
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
95
|
-
)
|
|
96
91
|
def __init__(
|
|
97
92
|
self,
|
|
98
93
|
sequential: SequentialDataset,
|
|
@@ -101,7 +96,7 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
101
96
|
sliding_window_step: Optional[int] = None,
|
|
102
97
|
label_feature_name: Optional[str] = None,
|
|
103
98
|
custom_masker: Optional[Bert4RecMasker] = None,
|
|
104
|
-
padding_value: int =
|
|
99
|
+
padding_value: Optional[int] = None,
|
|
105
100
|
) -> None:
|
|
106
101
|
"""
|
|
107
102
|
:param sequential: Sequential dataset with training data.
|
|
@@ -181,15 +176,11 @@ class Bert4RecPredictionDataset(TorchDataset):
|
|
|
181
176
|
Dataset that generates samples to infer BERT-like model
|
|
182
177
|
"""
|
|
183
178
|
|
|
184
|
-
@deprecation_warning(
|
|
185
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
186
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
187
|
-
)
|
|
188
179
|
def __init__(
|
|
189
180
|
self,
|
|
190
181
|
sequential: SequentialDataset,
|
|
191
182
|
max_sequence_length: int,
|
|
192
|
-
padding_value: int =
|
|
183
|
+
padding_value: Optional[int] = None,
|
|
193
184
|
) -> None:
|
|
194
185
|
"""
|
|
195
186
|
:param sequential: Sequential dataset with data to make predictions at.
|
|
@@ -239,17 +230,13 @@ class Bert4RecValidationDataset(TorchDataset):
|
|
|
239
230
|
Dataset that generates samples to infer and validate BERT-like model
|
|
240
231
|
"""
|
|
241
232
|
|
|
242
|
-
@deprecation_warning(
|
|
243
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
244
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
245
|
-
)
|
|
246
233
|
def __init__(
|
|
247
234
|
self,
|
|
248
235
|
sequential: SequentialDataset,
|
|
249
236
|
ground_truth: SequentialDataset,
|
|
250
237
|
train: SequentialDataset,
|
|
251
238
|
max_sequence_length: int,
|
|
252
|
-
padding_value: int =
|
|
239
|
+
padding_value: Optional[int] = None,
|
|
253
240
|
label_feature_name: Optional[str] = None,
|
|
254
241
|
):
|
|
255
242
|
"""
|
|
@@ -51,7 +51,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
51
51
|
|
|
52
52
|
def _compute_scores(self, query_ids: torch.LongTensor, scores: torch.Tensor) -> torch.Tensor:
|
|
53
53
|
flat_seen_item_ids = self._get_flat_seen_item_ids(query_ids)
|
|
54
|
-
return self._fill_item_ids(scores, flat_seen_item_ids, -np.inf)
|
|
54
|
+
return self._fill_item_ids(scores.clone(), flat_seen_item_ids, -np.inf)
|
|
55
55
|
|
|
56
56
|
def _fill_item_ids(
|
|
57
57
|
self,
|
|
@@ -10,7 +10,6 @@ from replay.data.nn import (
|
|
|
10
10
|
TorchSequentialDataset,
|
|
11
11
|
TorchSequentialValidationDataset,
|
|
12
12
|
)
|
|
13
|
-
from replay.utils import deprecation_warning
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
class SasRecTrainingBatch(NamedTuple):
|
|
@@ -31,17 +30,13 @@ class SasRecTrainingDataset(TorchDataset):
|
|
|
31
30
|
Dataset that generates samples to train SasRec-like model
|
|
32
31
|
"""
|
|
33
32
|
|
|
34
|
-
@deprecation_warning(
|
|
35
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
36
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
37
|
-
)
|
|
38
33
|
def __init__(
|
|
39
34
|
self,
|
|
40
35
|
sequential: SequentialDataset,
|
|
41
36
|
max_sequence_length: int,
|
|
42
37
|
sequence_shift: int = 1,
|
|
43
38
|
sliding_window_step: Optional[None] = None,
|
|
44
|
-
padding_value: int =
|
|
39
|
+
padding_value: Optional[int] = None,
|
|
45
40
|
label_feature_name: Optional[str] = None,
|
|
46
41
|
) -> None:
|
|
47
42
|
"""
|
|
@@ -127,15 +122,11 @@ class SasRecPredictionDataset(TorchDataset):
|
|
|
127
122
|
Dataset that generates samples to infer SasRec-like model
|
|
128
123
|
"""
|
|
129
124
|
|
|
130
|
-
@deprecation_warning(
|
|
131
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
132
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
133
|
-
)
|
|
134
125
|
def __init__(
|
|
135
126
|
self,
|
|
136
127
|
sequential: SequentialDataset,
|
|
137
128
|
max_sequence_length: int,
|
|
138
|
-
padding_value: int =
|
|
129
|
+
padding_value: Optional[int] = None,
|
|
139
130
|
) -> None:
|
|
140
131
|
"""
|
|
141
132
|
:param sequential: Sequential dataset with data to make predictions at.
|
|
@@ -179,17 +170,13 @@ class SasRecValidationDataset(TorchDataset):
|
|
|
179
170
|
Dataset that generates samples to infer and validate SasRec-like model
|
|
180
171
|
"""
|
|
181
172
|
|
|
182
|
-
@deprecation_warning(
|
|
183
|
-
"`padding_value` parameter will be removed in future versions. "
|
|
184
|
-
"Instead, you should specify `padding_value` for each column in TensorSchema"
|
|
185
|
-
)
|
|
186
173
|
def __init__(
|
|
187
174
|
self,
|
|
188
175
|
sequential: SequentialDataset,
|
|
189
176
|
ground_truth: SequentialDataset,
|
|
190
177
|
train: SequentialDataset,
|
|
191
178
|
max_sequence_length: int,
|
|
192
|
-
padding_value: int =
|
|
179
|
+
padding_value: Optional[int] = None,
|
|
193
180
|
label_feature_name: Optional[str] = None,
|
|
194
181
|
):
|
|
195
182
|
"""
|
replay/utils/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
replay/__init__.py,sha256=
|
|
1
|
+
replay/__init__.py,sha256=4xb9FHSuRPA_dhFTY5XvoJ7s_epCHAcBMiRPwORT_gQ,233
|
|
2
2
|
replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
|
|
3
3
|
replay/data/dataset.py,sha256=yQDc8lfphQYfHpm_T1MhnG8_GyM4ONyxJoFc1rUgdJ8,30755
|
|
4
4
|
replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
|
|
5
5
|
replay/data/dataset_utils/dataset_label_encoder.py,sha256=bxuJPhShFZBok7bQZYGNMV1etCLNTJUpyKO5MIwWack,9823
|
|
6
6
|
replay/data/nn/__init__.py,sha256=nj2Ep-tduuQkc-TnBkvN8-rDnFbcWO2oZrfcXl9M3C8,1122
|
|
7
7
|
replay/data/nn/schema.py,sha256=h1KgaNV-hgN9Vpt24c92EmeMpm_8W0s9a2M0wLxJHYk,17101
|
|
8
|
-
replay/data/nn/sequence_tokenizer.py,sha256=
|
|
9
|
-
replay/data/nn/sequential_dataset.py,sha256=
|
|
10
|
-
replay/data/nn/torch_sequential_dataset.py,sha256=
|
|
8
|
+
replay/data/nn/sequence_tokenizer.py,sha256=_9fBF-84jdn8Pa3pFKIr6prUjNYCc6BVzwRl9VSleKQ,37419
|
|
9
|
+
replay/data/nn/sequential_dataset.py,sha256=qthp87SQ44VpgoH3RKsqm6CxCeQyApn58l7_16txAZM,11303
|
|
10
|
+
replay/data/nn/torch_sequential_dataset.py,sha256=QSh4IM2vzAF095_ZMC1gMqZj9slHXos9gfx_R_DlpGM,11545
|
|
11
11
|
replay/data/nn/utils.py,sha256=Ic3G4yZRIzBYXLmwP1VstlZXPNR7AYGCc5EyZAERp5c,3297
|
|
12
12
|
replay/data/schema.py,sha256=JmYLCrNgBS5oq4O_PT724Gr1pDurHEykcqV8Xaj0XTw,15922
|
|
13
13
|
replay/data/spark_schema.py,sha256=4o0Kn_fjwz2-9dBY3q46F9PL0F3E7jdVpIlX7SG3OZI,1111
|
|
@@ -122,7 +122,7 @@ replay/models/extensions/ann/index_stores/utils.py,sha256=6r2GP_EFCaCguolW857pb4
|
|
|
122
122
|
replay/models/extensions/ann/utils.py,sha256=AgQvThi_DvEtakQeTno9hVZVWiWMFHKTjRcQ2wLa5vk,1222
|
|
123
123
|
replay/models/kl_ucb.py,sha256=L6vC2KsTBTTx4ckmGhWybOiLa5Wt54N7cgl7jS2FQRg,6731
|
|
124
124
|
replay/models/knn.py,sha256=HEiGHHQg9pV1_EIWZHfK-XD0BNAm1bj1c0ND9rYnj3k,8992
|
|
125
|
-
replay/models/lin_ucb.py,sha256=
|
|
125
|
+
replay/models/lin_ucb.py,sha256=iAR3PbbaQKqmisOKEx9ZyfpxnxcZomr6YauG4mvSakU,18800
|
|
126
126
|
replay/models/nn/__init__.py,sha256=AT3o1qXaxUq4_QIGlcGuSs54ZpueOo-SbpZwuGI-6os,41
|
|
127
127
|
replay/models/nn/loss/__init__.py,sha256=s3iO9QTZvLz_ony2b5K0hEmDmitrXQnAe9j6BRxLpR4,53
|
|
128
128
|
replay/models/nn/loss/sce.py,sha256=p6LFtoYSY4j2pQh6Z7i6cEADCmRnvTgnb8EJXseRKKg,5637
|
|
@@ -130,7 +130,7 @@ replay/models/nn/optimizer_utils/__init__.py,sha256=8MHln7CW54oACVUFKdZLjAf4bY83
|
|
|
130
130
|
replay/models/nn/optimizer_utils/optimizer_factory.py,sha256=1wicKnya2xrwDaHhqygy1VqB8-3jPDhMM7zY2TJE4dY,2844
|
|
131
131
|
replay/models/nn/sequential/__init__.py,sha256=CI2n0cxs_amqJrwBMq6n0Z_uBOu7CGXfagqvE4Jlmjw,128
|
|
132
132
|
replay/models/nn/sequential/bert4rec/__init__.py,sha256=JfZqHOGxcvOkICl5cWmZbZhaKXpkIvua-Wj57VWWEhw,399
|
|
133
|
-
replay/models/nn/sequential/bert4rec/dataset.py,sha256=
|
|
133
|
+
replay/models/nn/sequential/bert4rec/dataset.py,sha256=xd5a-yn5I280Vwoy_KtasDjrvksFolJYp71nDEHNUNQ,10414
|
|
134
134
|
replay/models/nn/sequential/bert4rec/lightning.py,sha256=_hP6_6E1SpGu6b_kiYEF4ZVhwKJ4sj_iPTo6loIvM0o,26546
|
|
135
135
|
replay/models/nn/sequential/bert4rec/model.py,sha256=2Lqvfz7UBB_ArqNs92OD5dy4a1onR4S5dNZiMbZgAgk,17388
|
|
136
136
|
replay/models/nn/sequential/callbacks/__init__.py,sha256=Q7mSZ_RB6iyD7QZaBL_NJ0uh8cRfgxq7gtPHbkSyhoo,282
|
|
@@ -142,9 +142,9 @@ replay/models/nn/sequential/compiled/bert4rec_compiled.py,sha256=Z6nfmdT70Wi-j7_
|
|
|
142
142
|
replay/models/nn/sequential/compiled/sasrec_compiled.py,sha256=qUaAwQOsBCstOG3RBlj_pJpD8BHmCpLZWCiPBlFVvT4,5856
|
|
143
143
|
replay/models/nn/sequential/postprocessors/__init__.py,sha256=89LGzkNHukcuC2-rfpiz7vmv1zyk6MNY-8zaXrvtn0M,164
|
|
144
144
|
replay/models/nn/sequential/postprocessors/_base.py,sha256=Q_SIYKG8G3U03IEK1dtlW1zJI300pOcWQYuMpkY0_nc,1111
|
|
145
|
-
replay/models/nn/sequential/postprocessors/postprocessors.py,sha256=
|
|
145
|
+
replay/models/nn/sequential/postprocessors/postprocessors.py,sha256=oijLByxuzegVmWZS-qRVhdO7ihqHer6SSGTFa8zX7I8,7810
|
|
146
146
|
replay/models/nn/sequential/sasrec/__init__.py,sha256=c6130lRpPkcbuGgkM7slagBIgH7Uk5zUtSzFDEwAsik,250
|
|
147
|
-
replay/models/nn/sequential/sasrec/dataset.py,sha256=
|
|
147
|
+
replay/models/nn/sequential/sasrec/dataset.py,sha256=L_LeRWqPc__390j8NWVskboS0NqbveIkLwFclcB4oDw,7189
|
|
148
148
|
replay/models/nn/sequential/sasrec/lightning.py,sha256=oScUyB8RU8N4MqWe6kAoWG0JW6Tkb2ldG_jdGFZgA7A,25060
|
|
149
149
|
replay/models/nn/sequential/sasrec/model.py,sha256=8kFovyPWqgQ0hmD3gckRjW7-hLBerl3bgYXCk4PYn0o,27656
|
|
150
150
|
replay/models/optimization/__init__.py,sha256=N8xCuzu0jQGwHrIBjuTRf-ZcZuBJ6FB0d9C5a7izJQU,338
|
|
@@ -177,7 +177,7 @@ replay/splitters/random_splitter.py,sha256=0DO0qulT0jp_GXswmFh3BMJ7utS-z9e-r5jIr
|
|
|
177
177
|
replay/splitters/ratio_splitter.py,sha256=rFWN-nKBYx1qKrmtYzjYf08DWFiKOCo5ZRUz-NHJFfs,17506
|
|
178
178
|
replay/splitters/time_splitter.py,sha256=0ZAMK26b--1wjrfzCuNVBh7gMPTa8SGf4LMEgACiUxA,9013
|
|
179
179
|
replay/splitters/two_stage_splitter.py,sha256=8Zn6BTJmZg04CD4l2jmil2dEu6xtglJaSS5mkotIXRc,17823
|
|
180
|
-
replay/utils/__init__.py,sha256=
|
|
180
|
+
replay/utils/__init__.py,sha256=3Skc9bUISEPPMMxdUCCT_S1q-i7cAT3KT0nExe-VMrw,343
|
|
181
181
|
replay/utils/common.py,sha256=92MTG51WpeEQJ2gu-WvdNe4Fmqm8ze-y1VNIAHW81jQ,5358
|
|
182
182
|
replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
|
|
183
183
|
replay/utils/distributions.py,sha256=UuhaC9HI6HnUXW97fEd-TsyDk4JT8t7k1T_6l5FpOMs,1203
|
|
@@ -186,9 +186,8 @@ replay/utils/session_handler.py,sha256=fQo2wseow8yuzKnEXT-aYAXcQIgRbTTXp0v7g1VVi
|
|
|
186
186
|
replay/utils/spark_utils.py,sha256=GbRp-MuUoO3Pc4chFvlmo9FskSlRLeNlC3Go5pEJ6Ok,27411
|
|
187
187
|
replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
|
|
188
188
|
replay/utils/types.py,sha256=rD9q9CqEXgF4yy512Hv2nXclvwcnfodOnhBZ1HSUI4c,1260
|
|
189
|
-
|
|
190
|
-
replay_rec-0.20.
|
|
191
|
-
replay_rec-0.20.
|
|
192
|
-
replay_rec-0.20.
|
|
193
|
-
replay_rec-0.20.
|
|
194
|
-
replay_rec-0.20.0rc0.dist-info/RECORD,,
|
|
189
|
+
replay_rec-0.20.1rc0.dist-info/METADATA,sha256=5QtFQnGuoWpCTBMgKX0q9O1tPdBqmTLTOcmzHP0VkNo,13155
|
|
190
|
+
replay_rec-0.20.1rc0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
191
|
+
replay_rec-0.20.1rc0.dist-info/licenses/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
|
|
192
|
+
replay_rec-0.20.1rc0.dist-info/licenses/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
|
|
193
|
+
replay_rec-0.20.1rc0.dist-info/RECORD,,
|
replay/utils/warnings.py
DELETED
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
import functools
|
|
2
|
-
import warnings
|
|
3
|
-
from collections.abc import Callable
|
|
4
|
-
from typing import Any, Optional
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
|
|
8
|
-
"""
|
|
9
|
-
Decorator that throws deprecation warnings.
|
|
10
|
-
|
|
11
|
-
:param message: message to deprecation warning without func name.
|
|
12
|
-
"""
|
|
13
|
-
base_msg = "will be deprecated in future versions."
|
|
14
|
-
|
|
15
|
-
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
16
|
-
@functools.wraps(func)
|
|
17
|
-
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
18
|
-
msg = f"{func.__qualname__} {message if message else base_msg}"
|
|
19
|
-
warnings.simplefilter("always", DeprecationWarning) # turn off filter
|
|
20
|
-
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
|
|
21
|
-
warnings.simplefilter("default", DeprecationWarning) # reset filter
|
|
22
|
-
return func(*args, **kwargs)
|
|
23
|
-
|
|
24
|
-
return wrapper
|
|
25
|
-
|
|
26
|
-
return decorator
|
|
File without changes
|
|
File without changes
|
|
File without changes
|