replay-rec 0.20.3rc0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/METADATA +18 -12
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- replay_rec-0.20.3rc0.dist-info/RECORD +0 -193
- /replay/{experimental → data/nn/parquet/constants}/__init__.py +0 -0
- /replay/{experimental/models/dt4rec → data/nn/parquet/impl}/__init__.py +0 -0
- /replay/{experimental/models/extensions/spark_custom_models → data/nn/parquet/info}/__init__.py +0 -0
- /replay/{experimental/scenarios/two_stages → data/nn/parquet/utils}/__init__.py +0 -0
- /replay/{experimental → data}/utils/__init__.py +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3rc0.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -17,7 +17,7 @@ from replay.data.nn import (
|
|
|
17
17
|
class Bert4RecTrainingBatch(NamedTuple):
|
|
18
18
|
"""
|
|
19
19
|
Batch of data for training.
|
|
20
|
-
Generated by
|
|
20
|
+
Generated by ``Bert4RecTrainingDataset``.
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
23
|
query_id: torch.LongTensor
|
|
@@ -26,6 +26,15 @@ class Bert4RecTrainingBatch(NamedTuple):
|
|
|
26
26
|
tokens_mask: torch.BoolTensor
|
|
27
27
|
labels: torch.LongTensor
|
|
28
28
|
|
|
29
|
+
def convert_to_dict(self) -> dict:
|
|
30
|
+
return {
|
|
31
|
+
"query_id": self.query_id,
|
|
32
|
+
"pad_mask": self.padding_mask,
|
|
33
|
+
"inputs": self.features,
|
|
34
|
+
"token_mask": self.tokens_mask,
|
|
35
|
+
"positive_labels": self.labels,
|
|
36
|
+
}
|
|
37
|
+
|
|
29
38
|
|
|
30
39
|
class Bert4RecMasker(abc.ABC):
|
|
31
40
|
"""
|
|
@@ -85,7 +94,12 @@ class Bert4RecUniformMasker(Bert4RecMasker):
|
|
|
85
94
|
|
|
86
95
|
class Bert4RecTrainingDataset(TorchDataset):
|
|
87
96
|
"""
|
|
88
|
-
Dataset that generates samples to train
|
|
97
|
+
Dataset that generates samples to train Bert4Rec model.
|
|
98
|
+
|
|
99
|
+
As a result of the dataset iteration, a dictionary is formed.
|
|
100
|
+
The keys in the dictionary match the names of the arguments in the model's `forward` function.
|
|
101
|
+
There are also additional keys needed to calculate losses - 'positive_labels`.
|
|
102
|
+
The `query_id` key is required for possible debugging and calling additional lightning callbacks.
|
|
89
103
|
"""
|
|
90
104
|
|
|
91
105
|
def __init__(
|
|
@@ -143,26 +157,26 @@ class Bert4RecTrainingDataset(TorchDataset):
|
|
|
143
157
|
def __len__(self) -> int:
|
|
144
158
|
return len(self._inner)
|
|
145
159
|
|
|
146
|
-
def __getitem__(self, index: int) ->
|
|
160
|
+
def __getitem__(self, index: int) -> dict:
|
|
147
161
|
query_id, padding_mask, features = self._inner[index]
|
|
148
162
|
tokens_mask = self._masker.mask(padding_mask)
|
|
149
163
|
|
|
150
164
|
assert self._label_feature_name
|
|
151
165
|
labels = features[self._label_feature_name]
|
|
152
166
|
|
|
153
|
-
return
|
|
154
|
-
query_id
|
|
155
|
-
padding_mask
|
|
156
|
-
features
|
|
157
|
-
tokens_mask
|
|
158
|
-
|
|
159
|
-
|
|
167
|
+
return {
|
|
168
|
+
"query_id": query_id,
|
|
169
|
+
"pad_mask": padding_mask,
|
|
170
|
+
"inputs": features,
|
|
171
|
+
"token_mask": tokens_mask,
|
|
172
|
+
"positive_labels": labels,
|
|
173
|
+
}
|
|
160
174
|
|
|
161
175
|
|
|
162
176
|
class Bert4RecPredictionBatch(NamedTuple):
|
|
163
177
|
"""
|
|
164
178
|
Batch of data for model inference.
|
|
165
|
-
Generated by
|
|
179
|
+
Generated by ``Bert4RecPredictionDataset``.
|
|
166
180
|
"""
|
|
167
181
|
|
|
168
182
|
query_id: torch.LongTensor
|
|
@@ -170,10 +184,22 @@ class Bert4RecPredictionBatch(NamedTuple):
|
|
|
170
184
|
features: TensorMap
|
|
171
185
|
tokens_mask: torch.BoolTensor
|
|
172
186
|
|
|
187
|
+
def convert_to_dict(self) -> dict:
|
|
188
|
+
return {
|
|
189
|
+
"query_id": self.query_id,
|
|
190
|
+
"pad_mask": self.padding_mask,
|
|
191
|
+
"inputs": self.features,
|
|
192
|
+
"token_mask": self.tokens_mask,
|
|
193
|
+
}
|
|
194
|
+
|
|
173
195
|
|
|
174
196
|
class Bert4RecPredictionDataset(TorchDataset):
|
|
175
197
|
"""
|
|
176
|
-
Dataset that generates samples to
|
|
198
|
+
Dataset that generates samples to inference Bert4Rec model
|
|
199
|
+
|
|
200
|
+
As a result of the dataset iteration, a dictionary is formed.
|
|
201
|
+
The keys in the dictionary match the names of the arguments in the model's `forward` function.
|
|
202
|
+
The `query_id` key is required for possible debugging and calling additional lightning callbacks.
|
|
177
203
|
"""
|
|
178
204
|
|
|
179
205
|
def __init__(
|
|
@@ -198,23 +224,23 @@ class Bert4RecPredictionDataset(TorchDataset):
|
|
|
198
224
|
def __len__(self) -> int:
|
|
199
225
|
return len(self._inner)
|
|
200
226
|
|
|
201
|
-
def __getitem__(self, index: int) ->
|
|
227
|
+
def __getitem__(self, index: int) -> dict:
|
|
202
228
|
query_id, padding_mask, features = self._inner[index]
|
|
203
229
|
|
|
204
230
|
shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
|
|
205
231
|
|
|
206
|
-
return
|
|
207
|
-
query_id
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
tokens_mask
|
|
211
|
-
|
|
232
|
+
return {
|
|
233
|
+
"query_id": query_id,
|
|
234
|
+
"pad_mask": shifted_padding_mask,
|
|
235
|
+
"inputs": shifted_features,
|
|
236
|
+
"token_mask": tokens_mask,
|
|
237
|
+
}
|
|
212
238
|
|
|
213
239
|
|
|
214
240
|
class Bert4RecValidationBatch(NamedTuple):
|
|
215
241
|
"""
|
|
216
242
|
Batch of data for validation.
|
|
217
|
-
Generated by
|
|
243
|
+
Generated by ``Bert4RecValidationDataset``.
|
|
218
244
|
"""
|
|
219
245
|
|
|
220
246
|
query_id: torch.LongTensor
|
|
@@ -224,10 +250,25 @@ class Bert4RecValidationBatch(NamedTuple):
|
|
|
224
250
|
ground_truth: torch.LongTensor
|
|
225
251
|
train: torch.LongTensor
|
|
226
252
|
|
|
253
|
+
def convert_to_dict(self) -> dict:
|
|
254
|
+
return {
|
|
255
|
+
"query_id": self.query_id,
|
|
256
|
+
"pad_mask": self.padding_mask,
|
|
257
|
+
"inputs": self.features,
|
|
258
|
+
"token_mask": self.tokens_mask,
|
|
259
|
+
"ground_truth": self.ground_truth,
|
|
260
|
+
"train": self.train,
|
|
261
|
+
}
|
|
262
|
+
|
|
227
263
|
|
|
228
264
|
class Bert4RecValidationDataset(TorchDataset):
|
|
229
265
|
"""
|
|
230
266
|
Dataset that generates samples to infer and validate BERT-like model
|
|
267
|
+
|
|
268
|
+
As a result of the dataset iteration, a dictionary is formed.
|
|
269
|
+
The keys in the dictionary match the names of the arguments in the model's `forward` function.
|
|
270
|
+
The `query_id` key is required for possible debugging and calling additional lightning callbacks.
|
|
271
|
+
Keys 'ground_truth` and `train` keys are required for metrics calculation on validation stage.
|
|
231
272
|
"""
|
|
232
273
|
|
|
233
274
|
def __init__(
|
|
@@ -263,19 +304,19 @@ class Bert4RecValidationDataset(TorchDataset):
|
|
|
263
304
|
def __len__(self) -> int:
|
|
264
305
|
return len(self._inner)
|
|
265
306
|
|
|
266
|
-
def __getitem__(self, index: int) ->
|
|
307
|
+
def __getitem__(self, index: int) -> dict:
|
|
267
308
|
query_id, padding_mask, features, ground_truth, train = self._inner[index]
|
|
268
309
|
|
|
269
310
|
shifted_features, shifted_padding_mask, tokens_mask = _shift_features(self._schema, features, padding_mask)
|
|
270
311
|
|
|
271
|
-
return
|
|
272
|
-
query_id
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
tokens_mask
|
|
276
|
-
ground_truth
|
|
277
|
-
train
|
|
278
|
-
|
|
312
|
+
return {
|
|
313
|
+
"query_id": query_id,
|
|
314
|
+
"pad_mask": shifted_padding_mask,
|
|
315
|
+
"inputs": shifted_features,
|
|
316
|
+
"token_mask": tokens_mask,
|
|
317
|
+
"ground_truth": ground_truth,
|
|
318
|
+
"train": train,
|
|
319
|
+
}
|
|
279
320
|
|
|
280
321
|
|
|
281
322
|
def _shift_features(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import math
|
|
2
|
+
import warnings
|
|
2
3
|
from typing import Any, Literal, Optional, Union, cast
|
|
3
4
|
|
|
4
5
|
import lightning
|
|
@@ -29,13 +30,13 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
29
30
|
enable_embedding_tying: bool = False,
|
|
30
31
|
loss_type: Literal["BCE", "CE", "CE_restricted"] = "CE",
|
|
31
32
|
loss_sample_count: Optional[int] = None,
|
|
32
|
-
negative_sampling_strategy:
|
|
33
|
+
negative_sampling_strategy: Literal["global_uniform", "inbatch"] = "global_uniform",
|
|
33
34
|
negatives_sharing: bool = False,
|
|
34
35
|
optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
|
|
35
36
|
lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
|
|
36
37
|
):
|
|
37
38
|
"""
|
|
38
|
-
:param tensor_schema
|
|
39
|
+
:param tensor_schema: Tensor schema of features.
|
|
39
40
|
:param block_count: Number of Transformer blocks.
|
|
40
41
|
Default: ``2``.
|
|
41
42
|
:param head_count: Number of Attention heads.
|
|
@@ -44,7 +45,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
44
45
|
Default: ``256``.
|
|
45
46
|
:param max_seq_len: Max length of sequence.
|
|
46
47
|
Default: ``100``.
|
|
47
|
-
:param dropout_rate
|
|
48
|
+
:param dropout_rate: Dropout rate.
|
|
48
49
|
Default: ``0.1``.
|
|
49
50
|
:param pass_per_transformer_block_count: Number of times to pass data over each Transformer block.
|
|
50
51
|
Default: ``1``.
|
|
@@ -54,19 +55,18 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
54
55
|
If `True` - result scores are calculated by dot product of input and output embeddings,
|
|
55
56
|
if `False` - default linear layer is applied to calculate logits for each item.
|
|
56
57
|
Default: ``False``.
|
|
57
|
-
:param loss_type: Loss type.
|
|
58
|
+
:param loss_type: Loss type.
|
|
58
59
|
Default: ``CE``.
|
|
59
|
-
:param loss_sample_count
|
|
60
|
+
:param loss_sample_count: Sample count to calculate loss.
|
|
60
61
|
Default: ``None``.
|
|
61
62
|
:param negative_sampling_strategy: Negative sampling strategy to calculate loss on sampled negatives.
|
|
62
|
-
Is used when large count of items in dataset
|
|
63
|
-
Possible values: ``"global_uniform"``, ``"inbatch"``
|
|
63
|
+
Is used when large count of items in dataset.\n
|
|
64
64
|
Default: ``global_uniform``.
|
|
65
|
-
:param negatives_sharing: Apply negative sharing in calculating sampled logits
|
|
65
|
+
:param negatives_sharing: Apply negative sharing in calculating sampled logits.\n
|
|
66
66
|
Default: ``False``.
|
|
67
|
-
:param optimizer_factory: Optimizer factory
|
|
67
|
+
:param optimizer_factory: Optimizer factory.\n
|
|
68
68
|
Default: ``FatOptimizerFactory``.
|
|
69
|
-
:param lr_scheduler_factory: Learning rate schedule factory
|
|
69
|
+
:param lr_scheduler_factory: Learning rate schedule factory.\n
|
|
70
70
|
Default: ``None``.
|
|
71
71
|
"""
|
|
72
72
|
super().__init__()
|
|
@@ -97,7 +97,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
97
97
|
self._vocab_size = item_count
|
|
98
98
|
self.candidates_to_score = None
|
|
99
99
|
|
|
100
|
-
def training_step(self, batch: Bert4RecTrainingBatch, batch_idx: int) -> torch.Tensor: # noqa: ARG002
|
|
100
|
+
def training_step(self, batch: Union[Bert4RecTrainingBatch, dict], batch_idx: int) -> torch.Tensor: # noqa: ARG002
|
|
101
101
|
"""
|
|
102
102
|
:param batch: Batch of training data.
|
|
103
103
|
:param batch_idx: Batch index.
|
|
@@ -109,7 +109,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
109
109
|
return loss
|
|
110
110
|
|
|
111
111
|
def predict_step(
|
|
112
|
-
self, batch: Bert4RecPredictionBatch, batch_idx: int, dataloader_idx: int = 0 # noqa: ARG002
|
|
112
|
+
self, batch: Union[Bert4RecPredictionBatch, dict], batch_idx: int, dataloader_idx: int = 0 # noqa: ARG002
|
|
113
113
|
) -> torch.Tensor:
|
|
114
114
|
"""
|
|
115
115
|
:param batch (Bert4RecPredictionBatch): Batch of prediction data.
|
|
@@ -118,23 +118,49 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
118
118
|
|
|
119
119
|
:returns: Calculated scores on prediction batch.
|
|
120
120
|
"""
|
|
121
|
+
if isinstance(batch, Bert4RecPredictionBatch):
|
|
122
|
+
warnings.warn(
|
|
123
|
+
"`Bert4RecPredictionBatch` class will be removed in future versions. "
|
|
124
|
+
"Instead, you should use simple dictionary",
|
|
125
|
+
DeprecationWarning,
|
|
126
|
+
stacklevel=2,
|
|
127
|
+
)
|
|
128
|
+
batch = batch.convert_to_dict()
|
|
121
129
|
batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
|
|
122
|
-
return self._model_predict(
|
|
130
|
+
return self._model_predict(
|
|
131
|
+
feature_tensors=batch["inputs"],
|
|
132
|
+
padding_mask=batch["pad_mask"],
|
|
133
|
+
tokens_mask=batch["token_mask"],
|
|
134
|
+
)
|
|
123
135
|
|
|
124
136
|
def predict(
|
|
125
137
|
self,
|
|
126
|
-
batch: Bert4RecPredictionBatch,
|
|
138
|
+
batch: Union[Bert4RecPredictionBatch, dict],
|
|
127
139
|
candidates_to_score: Optional[torch.LongTensor] = None,
|
|
128
140
|
) -> torch.Tensor:
|
|
129
141
|
"""
|
|
130
|
-
:param batch
|
|
142
|
+
:param batch: Batch of prediction data.
|
|
131
143
|
:param candidates_to_score: Item ids to calculate scores.
|
|
132
144
|
Default: ``None``.
|
|
133
145
|
|
|
134
146
|
:returns: Calculated scores on prediction batch.
|
|
135
147
|
"""
|
|
148
|
+
if isinstance(batch, Bert4RecPredictionBatch):
|
|
149
|
+
warnings.warn(
|
|
150
|
+
"`Bert4RecPredictionBatch` class will be removed in future versions. "
|
|
151
|
+
"Instead, you should use simple dictionary",
|
|
152
|
+
DeprecationWarning,
|
|
153
|
+
stacklevel=2,
|
|
154
|
+
)
|
|
155
|
+
batch = batch.convert_to_dict()
|
|
156
|
+
|
|
136
157
|
batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
|
|
137
|
-
return self._model_predict(
|
|
158
|
+
return self._model_predict(
|
|
159
|
+
feature_tensors=batch["inputs"],
|
|
160
|
+
padding_mask=batch["pad_mask"],
|
|
161
|
+
tokens_mask=batch["token_mask"],
|
|
162
|
+
candidates_to_score=candidates_to_score,
|
|
163
|
+
)
|
|
138
164
|
|
|
139
165
|
def forward(
|
|
140
166
|
self,
|
|
@@ -152,10 +178,15 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
152
178
|
|
|
153
179
|
:returns: Calculated scores.
|
|
154
180
|
"""
|
|
155
|
-
return self._model_predict(
|
|
181
|
+
return self._model_predict(
|
|
182
|
+
feature_tensors=feature_tensors,
|
|
183
|
+
padding_mask=padding_mask,
|
|
184
|
+
tokens_mask=tokens_mask,
|
|
185
|
+
candidates_to_score=candidates_to_score,
|
|
186
|
+
)
|
|
156
187
|
|
|
157
188
|
def validation_step(
|
|
158
|
-
self, batch: Bert4RecValidationBatch, batch_idx: int, dataloader_idx: int = 0 # noqa: ARG002
|
|
189
|
+
self, batch: Union[Bert4RecValidationBatch, dict], batch_idx: int, dataloader_idx: int = 0 # noqa: ARG002
|
|
159
190
|
) -> torch.Tensor:
|
|
160
191
|
"""
|
|
161
192
|
:param batch: Batch of prediction data.
|
|
@@ -163,7 +194,20 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
163
194
|
|
|
164
195
|
:returns: Calculated scores on validation batch.
|
|
165
196
|
"""
|
|
166
|
-
|
|
197
|
+
if isinstance(batch, Bert4RecValidationBatch):
|
|
198
|
+
warnings.warn(
|
|
199
|
+
"`Bert4RecValidationBatch` class will be removed in future versions. "
|
|
200
|
+
"Instead, you should use simple dictionary",
|
|
201
|
+
DeprecationWarning,
|
|
202
|
+
stacklevel=2,
|
|
203
|
+
)
|
|
204
|
+
batch = batch.convert_to_dict()
|
|
205
|
+
|
|
206
|
+
return self._model_predict(
|
|
207
|
+
feature_tensors=batch["inputs"],
|
|
208
|
+
padding_mask=batch["pad_mask"],
|
|
209
|
+
tokens_mask=batch["token_mask"],
|
|
210
|
+
)
|
|
167
211
|
|
|
168
212
|
def configure_optimizers(self) -> Any:
|
|
169
213
|
"""
|
|
@@ -189,10 +233,15 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
189
233
|
cast(Bert4RecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
|
|
190
234
|
)
|
|
191
235
|
candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
|
|
192
|
-
scores = model.predict(
|
|
236
|
+
scores = model.predict(
|
|
237
|
+
inputs=feature_tensors,
|
|
238
|
+
pad_mask=padding_mask,
|
|
239
|
+
token_mask=tokens_mask,
|
|
240
|
+
candidates_to_score=candidates_to_score,
|
|
241
|
+
)
|
|
193
242
|
return scores
|
|
194
243
|
|
|
195
|
-
def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
|
|
244
|
+
def _compute_loss(self, batch: Union[Bert4RecTrainingBatch, dict]) -> torch.Tensor:
|
|
196
245
|
if self._loss_type == "BCE":
|
|
197
246
|
loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
|
|
198
247
|
elif self._loss_type == "CE":
|
|
@@ -203,11 +252,20 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
203
252
|
msg = f"Not supported loss type: {self._loss_type}"
|
|
204
253
|
raise ValueError(msg)
|
|
205
254
|
|
|
255
|
+
if isinstance(batch, Bert4RecTrainingBatch):
|
|
256
|
+
warnings.warn(
|
|
257
|
+
"`Bert4RecTrainingBatch` class will be removed in future versions. "
|
|
258
|
+
"Instead, you should use simple dictionary",
|
|
259
|
+
DeprecationWarning,
|
|
260
|
+
stacklevel=2,
|
|
261
|
+
)
|
|
262
|
+
batch = batch.convert_to_dict()
|
|
263
|
+
|
|
206
264
|
loss = loss_func(
|
|
207
|
-
batch
|
|
208
|
-
batch
|
|
209
|
-
batch
|
|
210
|
-
batch
|
|
265
|
+
batch["inputs"],
|
|
266
|
+
batch["positive_labels"],
|
|
267
|
+
batch["pad_mask"],
|
|
268
|
+
batch["token_mask"],
|
|
211
269
|
)
|
|
212
270
|
|
|
213
271
|
return loss
|
|
@@ -253,7 +311,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
253
311
|
padding_mask: torch.BoolTensor,
|
|
254
312
|
tokens_mask: torch.BoolTensor,
|
|
255
313
|
) -> torch.Tensor:
|
|
256
|
-
|
|
314
|
+
positive_logits, negative_logits, *_ = self._get_sampled_logits(
|
|
257
315
|
feature_tensors, positive_labels, padding_mask, tokens_mask
|
|
258
316
|
)
|
|
259
317
|
|
|
@@ -300,7 +358,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
300
358
|
tokens_mask: torch.BoolTensor,
|
|
301
359
|
) -> torch.Tensor:
|
|
302
360
|
assert self._loss_sample_count is not None
|
|
303
|
-
|
|
361
|
+
positive_logits, negative_logits, positive_labels, negative_labels, vocab_size = self._get_sampled_logits(
|
|
304
362
|
feature_tensors, positive_labels, padding_mask, tokens_mask
|
|
305
363
|
)
|
|
306
364
|
n_negative_samples = min(self._loss_sample_count, vocab_size)
|
|
@@ -325,7 +383,7 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
325
383
|
padding_mask: torch.BoolTensor,
|
|
326
384
|
tokens_mask: torch.BoolTensor,
|
|
327
385
|
) -> torch.Tensor:
|
|
328
|
-
|
|
386
|
+
logits, labels = self._get_restricted_logits_for_ce_loss(
|
|
329
387
|
feature_tensors, positive_labels, padding_mask, tokens_mask
|
|
330
388
|
)
|
|
331
389
|
|
|
@@ -588,20 +646,20 @@ class Bert4Rec(lightning.LightningModule):
|
|
|
588
646
|
self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
|
|
589
647
|
|
|
590
648
|
|
|
591
|
-
def _prepare_prediction_batch(
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
if batch.padding_mask.shape[1] > max_len:
|
|
649
|
+
def _prepare_prediction_batch(schema: TensorSchema, max_len: int, batch: dict) -> dict:
|
|
650
|
+
seq_len = batch["pad_mask"].shape[1]
|
|
651
|
+
if seq_len > max_len:
|
|
595
652
|
msg = (
|
|
596
653
|
f"The length of the submitted sequence "
|
|
597
654
|
"must not exceed the maximum length of the sequence. "
|
|
598
|
-
f"The length of the sequence is given {
|
|
655
|
+
f"The length of the sequence is given {seq_len}, "
|
|
599
656
|
f"while the maximum length is {max_len}"
|
|
600
657
|
)
|
|
601
658
|
raise ValueError(msg)
|
|
602
659
|
|
|
603
|
-
if
|
|
604
|
-
|
|
660
|
+
if seq_len < max_len:
|
|
661
|
+
padding_mask = batch["pad_mask"]
|
|
662
|
+
features = batch["inputs"].copy()
|
|
605
663
|
sequence_item_count = padding_mask.shape[1]
|
|
606
664
|
for feature_name, feature_tensor in features.items():
|
|
607
665
|
if schema[feature_name].is_cat:
|
|
@@ -618,5 +676,8 @@ def _prepare_prediction_batch(
|
|
|
618
676
|
).unsqueeze(-1)
|
|
619
677
|
padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
|
|
620
678
|
shifted_features, shifted_padding_mask, tokens_mask = _shift_features(schema, features, padding_mask)
|
|
621
|
-
|
|
679
|
+
|
|
680
|
+
batch["pad_mask"] = shifted_padding_mask
|
|
681
|
+
batch["inputs"] = shifted_features
|
|
682
|
+
batch["token_mask"] = tokens_mask
|
|
622
683
|
return batch
|
|
@@ -88,8 +88,8 @@ class Bert4RecModel(torch.nn.Module):
|
|
|
88
88
|
def forward(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor) -> torch.Tensor:
|
|
89
89
|
"""
|
|
90
90
|
:param inputs: Batch of features.
|
|
91
|
-
:param pad_mask: Padding mask where 0 -
|
|
92
|
-
:param token_mask: Token mask where 0 -
|
|
91
|
+
:param pad_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
|
|
92
|
+
:param token_mask: Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
|
|
93
93
|
|
|
94
94
|
:returns: Calculated scores.
|
|
95
95
|
"""
|
|
@@ -107,12 +107,12 @@ class Bert4RecModel(torch.nn.Module):
|
|
|
107
107
|
) -> torch.Tensor:
|
|
108
108
|
"""
|
|
109
109
|
:param inputs: Batch of features.
|
|
110
|
-
:param pad_mask: Padding mask where 0 -
|
|
111
|
-
:param token_mask: Token mask where 0 -
|
|
112
|
-
:param candidates_to_score: Item ids to calculate scores
|
|
113
|
-
|
|
110
|
+
:param pad_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
|
|
111
|
+
:param token_mask: Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
|
|
112
|
+
:param candidates_to_score: Item ids to calculate scores.\n
|
|
113
|
+
If ``None`` then predicts for all items. Default: ``None``.
|
|
114
114
|
|
|
115
|
-
:returns: Calculated scores among canditates_to_score items.
|
|
115
|
+
:returns: Calculated scores among ``canditates_to_score`` items.
|
|
116
116
|
"""
|
|
117
117
|
# final_emb: [B x E]
|
|
118
118
|
final_emb = self.get_query_embeddings(inputs, pad_mask, token_mask)
|
|
@@ -123,8 +123,8 @@ class Bert4RecModel(torch.nn.Module):
|
|
|
123
123
|
"""
|
|
124
124
|
|
|
125
125
|
:param inputs (TensorMap): Batch of features.
|
|
126
|
-
:param pad_mask (torch.BoolTensor): Padding mask where 0 -
|
|
127
|
-
:param token_mask (torch.BoolTensor): Token mask where 0 -
|
|
126
|
+
:param pad_mask (torch.BoolTensor): Padding mask where 0 - ``<PAD>``, 1 - otherwise.
|
|
127
|
+
:param token_mask (torch.BoolTensor): Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
|
|
128
128
|
|
|
129
129
|
:returns: Output embeddings.
|
|
130
130
|
"""
|
|
@@ -158,8 +158,8 @@ class Bert4RecModel(torch.nn.Module):
|
|
|
158
158
|
def get_query_embeddings(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor):
|
|
159
159
|
"""
|
|
160
160
|
:param inputs: Batch of features.
|
|
161
|
-
:param pad_mask: Padding mask where 0 -
|
|
162
|
-
:param token_mask: Token mask where 0 -
|
|
161
|
+
:param pad_mask: Padding mask where 0 - ``<PAD>``, 1 - otherwise.
|
|
162
|
+
:param token_mask: Token mask where 0 - ``<MASK>`` tokens, 1 - otherwise.
|
|
163
163
|
|
|
164
164
|
:returns: Query embeddings.
|
|
165
165
|
"""
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
|
|
2
|
+
import inspect
|
|
3
|
+
from typing import Generic, Optional, Protocol, TypeVar, Union, cast
|
|
3
4
|
|
|
4
5
|
import lightning
|
|
5
6
|
import torch
|
|
7
|
+
from typing_extensions import deprecated
|
|
6
8
|
|
|
7
9
|
from replay.models.nn.sequential import Bert4Rec
|
|
8
10
|
from replay.models.nn.sequential.postprocessors import BasePostProcessor
|
|
@@ -16,6 +18,7 @@ else:
|
|
|
16
18
|
SparkSession = MissingImport
|
|
17
19
|
|
|
18
20
|
|
|
21
|
+
@deprecated("`PredictionBatch` class is deprecated.", stacklevel=2)
|
|
19
22
|
class PredictionBatch(Protocol):
|
|
20
23
|
"""
|
|
21
24
|
Prediction callback batch
|
|
@@ -27,6 +30,10 @@ class PredictionBatch(Protocol):
|
|
|
27
30
|
_T = TypeVar("_T")
|
|
28
31
|
|
|
29
32
|
|
|
33
|
+
@deprecated(
|
|
34
|
+
"`BasePredictionCallback` class is deprecated. Use `replay.nn.lightning.callback.TopItemsCallbackBase` instead.",
|
|
35
|
+
stacklevel=2,
|
|
36
|
+
)
|
|
30
37
|
class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
31
38
|
"""
|
|
32
39
|
Base callback for prediction stage
|
|
@@ -48,6 +55,7 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
|
48
55
|
:param postprocessors: postprocessors to apply.
|
|
49
56
|
"""
|
|
50
57
|
super().__init__()
|
|
58
|
+
|
|
51
59
|
self.query_column = query_column
|
|
52
60
|
self.item_column = item_column
|
|
53
61
|
self.rating_column = rating_column
|
|
@@ -74,11 +82,14 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
|
74
82
|
trainer: lightning.Trainer, # noqa: ARG002
|
|
75
83
|
pl_module: lightning.LightningModule, # noqa: ARG002
|
|
76
84
|
outputs: torch.Tensor,
|
|
77
|
-
batch: PredictionBatch,
|
|
85
|
+
batch: Union[PredictionBatch, dict],
|
|
78
86
|
batch_idx: int, # noqa: ARG002
|
|
79
87
|
dataloader_idx: int = 0, # noqa: ARG002
|
|
80
88
|
) -> None:
|
|
81
|
-
query_ids, scores = self._compute_pipeline(
|
|
89
|
+
query_ids, scores = self._compute_pipeline(
|
|
90
|
+
batch["query_id"] if isinstance(batch, dict) else batch.query_id,
|
|
91
|
+
outputs,
|
|
92
|
+
)
|
|
82
93
|
top_scores, top_item_ids = torch.topk(scores, k=self._top_k, dim=1)
|
|
83
94
|
self._query_batches.append(query_ids)
|
|
84
95
|
self._item_batches.append(top_item_ids)
|
|
@@ -112,6 +123,10 @@ class BasePredictionCallback(lightning.Callback, Generic[_T]):
|
|
|
112
123
|
pass
|
|
113
124
|
|
|
114
125
|
|
|
126
|
+
@deprecated(
|
|
127
|
+
"`PandasPredictionCallback` class is deprecated. "
|
|
128
|
+
"Use `replay.nn.lightning.callback.PandasTopItemsCallback` instead."
|
|
129
|
+
)
|
|
115
130
|
class PandasPredictionCallback(BasePredictionCallback[PandasDataFrame]):
|
|
116
131
|
"""
|
|
117
132
|
Callback for predition stage with pandas data frame
|
|
@@ -133,6 +148,10 @@ class PandasPredictionCallback(BasePredictionCallback[PandasDataFrame]):
|
|
|
133
148
|
return prediction.explode([self.item_column, self.rating_column])
|
|
134
149
|
|
|
135
150
|
|
|
151
|
+
@deprecated(
|
|
152
|
+
"`PolarsPredictionCallback` class is deprecated. "
|
|
153
|
+
"Use `replay.nn.lightning.callback.PolarsTopItemsCallback` instead."
|
|
154
|
+
)
|
|
136
155
|
class PolarsPredictionCallback(BasePredictionCallback[PolarsDataFrame]):
|
|
137
156
|
"""
|
|
138
157
|
Callback for predition stage with polars data frame
|
|
@@ -154,6 +173,10 @@ class PolarsPredictionCallback(BasePredictionCallback[PolarsDataFrame]):
|
|
|
154
173
|
return prediction.explode([self.item_column, self.rating_column])
|
|
155
174
|
|
|
156
175
|
|
|
176
|
+
@deprecated(
|
|
177
|
+
"`SparkPredictionCallback` class is deprecated. "
|
|
178
|
+
"Use `replay.nn.lightning.callback.SparkTopItemsCallback` instead."
|
|
179
|
+
)
|
|
157
180
|
class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
|
|
158
181
|
"""
|
|
159
182
|
Callback for prediction stage with spark data frame
|
|
@@ -213,6 +236,10 @@ class SparkPredictionCallback(BasePredictionCallback[SparkDataFrame]):
|
|
|
213
236
|
return prediction
|
|
214
237
|
|
|
215
238
|
|
|
239
|
+
@deprecated(
|
|
240
|
+
"`TorchPredictionCallback` class is deprecated. "
|
|
241
|
+
"Use `replay.nn.lightning.callback.TorchTopItemsCallback` instead."
|
|
242
|
+
)
|
|
216
243
|
class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]]):
|
|
217
244
|
"""
|
|
218
245
|
Callback for predition stage with tuple of tensors
|
|
@@ -248,6 +275,10 @@ class TorchPredictionCallback(BasePredictionCallback[tuple[torch.LongTensor, tor
|
|
|
248
275
|
)
|
|
249
276
|
|
|
250
277
|
|
|
278
|
+
@deprecated(
|
|
279
|
+
"`QueryEmbeddingsPredictionCallback` class is deprecated. "
|
|
280
|
+
"Use `replay.nn.lightning.callback.HiddenStatesCallback` instead."
|
|
281
|
+
)
|
|
251
282
|
class QueryEmbeddingsPredictionCallback(lightning.Callback):
|
|
252
283
|
"""
|
|
253
284
|
Callback for prediction stage to get query embeddings.
|
|
@@ -266,15 +297,26 @@ class QueryEmbeddingsPredictionCallback(lightning.Callback):
|
|
|
266
297
|
trainer: lightning.Trainer, # noqa: ARG002
|
|
267
298
|
pl_module: lightning.LightningModule,
|
|
268
299
|
outputs: torch.Tensor, # noqa: ARG002
|
|
269
|
-
batch: PredictionBatch,
|
|
300
|
+
batch: Union[PredictionBatch, dict],
|
|
270
301
|
batch_idx: int, # noqa: ARG002
|
|
271
302
|
dataloader_idx: int = 0, # noqa: ARG002
|
|
272
303
|
) -> None:
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
304
|
+
if isinstance(batch, dict):
|
|
305
|
+
modified_batch = {
|
|
306
|
+
k: v
|
|
307
|
+
for k, v in batch.items()
|
|
308
|
+
if k in inspect.signature(pl_module._model.get_query_embeddings).parameters
|
|
309
|
+
}
|
|
310
|
+
query_embeddings = pl_module._model.get_query_embeddings(**modified_batch)
|
|
311
|
+
else:
|
|
312
|
+
args = [
|
|
313
|
+
batch.features,
|
|
314
|
+
batch.padding_mask,
|
|
315
|
+
]
|
|
316
|
+
if isinstance(pl_module, Bert4Rec):
|
|
317
|
+
args.append(batch.tokens_mask)
|
|
318
|
+
query_embeddings = pl_module._model.get_query_embeddings(*args)
|
|
276
319
|
|
|
277
|
-
query_embeddings = pl_module._model.get_query_embeddings(*args)
|
|
278
320
|
self._embeddings_per_batch.append(query_embeddings)
|
|
279
321
|
|
|
280
322
|
def get_result(self):
|