replay-rec 0.20.3__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +11 -0
- replay/data/nn/__init__.py +3 -0
- replay/data/nn/parquet/__init__.py +22 -0
- replay/data/nn/parquet/collate.py +29 -0
- replay/data/nn/parquet/constants/__init__.py +0 -0
- replay/data/nn/parquet/constants/batches.py +8 -0
- replay/data/nn/parquet/constants/device.py +3 -0
- replay/data/nn/parquet/constants/filesystem.py +3 -0
- replay/data/nn/parquet/constants/metadata.py +5 -0
- replay/data/nn/parquet/fixed_batch_dataset.py +157 -0
- replay/data/nn/parquet/impl/__init__.py +0 -0
- replay/data/nn/parquet/impl/array_1d_column.py +140 -0
- replay/data/nn/parquet/impl/array_2d_column.py +160 -0
- replay/data/nn/parquet/impl/column_protocol.py +17 -0
- replay/data/nn/parquet/impl/indexing.py +123 -0
- replay/data/nn/parquet/impl/masking.py +20 -0
- replay/data/nn/parquet/impl/named_columns.py +100 -0
- replay/data/nn/parquet/impl/numeric_column.py +110 -0
- replay/data/nn/parquet/impl/utils.py +17 -0
- replay/data/nn/parquet/info/__init__.py +0 -0
- replay/data/nn/parquet/info/distributed_info.py +40 -0
- replay/data/nn/parquet/info/partitioning.py +132 -0
- replay/data/nn/parquet/info/replicas.py +67 -0
- replay/data/nn/parquet/info/worker_info.py +43 -0
- replay/data/nn/parquet/iterable_dataset.py +119 -0
- replay/data/nn/parquet/iterator.py +61 -0
- replay/data/nn/parquet/metadata/__init__.py +19 -0
- replay/data/nn/parquet/metadata/metadata.py +116 -0
- replay/data/nn/parquet/parquet_dataset.py +176 -0
- replay/data/nn/parquet/parquet_module.py +178 -0
- replay/data/nn/parquet/partitioned_iterable_dataset.py +56 -0
- replay/data/nn/parquet/utils/__init__.py +0 -0
- replay/data/nn/parquet/utils/compute_length.py +66 -0
- replay/data/nn/schema.py +12 -14
- replay/data/nn/sequence_tokenizer.py +5 -0
- replay/data/nn/sequential_dataset.py +4 -0
- replay/data/nn/torch_sequential_dataset.py +5 -0
- replay/data/utils/__init__.py +0 -0
- replay/data/utils/batching.py +69 -0
- replay/data/utils/typing/__init__.py +0 -0
- replay/data/utils/typing/dtype.py +65 -0
- replay/metrics/torch_metrics_builder.py +20 -14
- replay/models/nn/loss/sce.py +2 -7
- replay/models/nn/optimizer_utils/__init__.py +6 -1
- replay/models/nn/optimizer_utils/optimizer_factory.py +15 -0
- replay/models/nn/sequential/bert4rec/dataset.py +70 -29
- replay/models/nn/sequential/bert4rec/lightning.py +97 -36
- replay/models/nn/sequential/bert4rec/model.py +11 -11
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +50 -8
- replay/models/nn/sequential/callbacks/validation_callback.py +23 -6
- replay/models/nn/sequential/compiled/base_compiled_model.py +12 -4
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +15 -5
- replay/models/nn/sequential/compiled/sasrec_compiled.py +16 -7
- replay/models/nn/sequential/postprocessors/_base.py +5 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +4 -0
- replay/models/nn/sequential/sasrec/dataset.py +81 -26
- replay/models/nn/sequential/sasrec/lightning.py +86 -24
- replay/models/nn/sequential/sasrec/model.py +14 -9
- replay/nn/__init__.py +8 -0
- replay/nn/agg.py +109 -0
- replay/nn/attention.py +158 -0
- replay/nn/embedding.py +283 -0
- replay/nn/ffn.py +135 -0
- replay/nn/head.py +49 -0
- replay/nn/lightning/__init__.py +1 -0
- replay/nn/lightning/callback/__init__.py +9 -0
- replay/nn/lightning/callback/metrics_callback.py +183 -0
- replay/nn/lightning/callback/predictions_callback.py +314 -0
- replay/nn/lightning/module.py +123 -0
- replay/nn/lightning/optimizer.py +60 -0
- replay/nn/lightning/postprocessor/__init__.py +2 -0
- replay/nn/lightning/postprocessor/_base.py +51 -0
- replay/nn/lightning/postprocessor/seen_items.py +83 -0
- replay/nn/lightning/scheduler.py +91 -0
- replay/nn/loss/__init__.py +22 -0
- replay/nn/loss/base.py +197 -0
- replay/nn/loss/bce.py +216 -0
- replay/nn/loss/ce.py +317 -0
- replay/nn/loss/login_ce.py +373 -0
- replay/nn/loss/logout_ce.py +230 -0
- replay/nn/mask.py +87 -0
- replay/nn/normalization.py +9 -0
- replay/nn/output.py +37 -0
- replay/nn/sequential/__init__.py +9 -0
- replay/nn/sequential/sasrec/__init__.py +7 -0
- replay/nn/sequential/sasrec/agg.py +53 -0
- replay/nn/sequential/sasrec/diff_transformer.py +125 -0
- replay/nn/sequential/sasrec/model.py +377 -0
- replay/nn/sequential/sasrec/transformer.py +107 -0
- replay/nn/sequential/twotower/__init__.py +2 -0
- replay/nn/sequential/twotower/model.py +674 -0
- replay/nn/sequential/twotower/reader.py +89 -0
- replay/nn/transform/__init__.py +22 -0
- replay/nn/transform/copy.py +38 -0
- replay/nn/transform/grouping.py +39 -0
- replay/nn/transform/negative_sampling.py +182 -0
- replay/nn/transform/next_token.py +100 -0
- replay/nn/transform/rename.py +33 -0
- replay/nn/transform/reshape.py +41 -0
- replay/nn/transform/sequence_roll.py +48 -0
- replay/nn/transform/template/__init__.py +2 -0
- replay/nn/transform/template/sasrec.py +53 -0
- replay/nn/transform/template/twotower.py +22 -0
- replay/nn/transform/token_mask.py +69 -0
- replay/nn/transform/trim.py +51 -0
- replay/nn/utils.py +28 -0
- replay/preprocessing/filters.py +128 -0
- replay/preprocessing/label_encoder.py +36 -33
- replay/preprocessing/utils.py +209 -0
- replay/splitters/__init__.py +1 -0
- replay/splitters/random_next_n_splitter.py +224 -0
- replay/utils/common.py +10 -4
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/METADATA +3 -3
- replay_rec-0.21.0.dist-info/RECORD +223 -0
- replay_rec-0.20.3.dist-info/RECORD +0 -138
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.3.dist-info → replay_rec-0.21.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
from typing import Any, Literal, Optional, Protocol
|
|
1
|
+
from typing import Any, Literal, Optional, Protocol, Union
|
|
2
2
|
|
|
3
3
|
import lightning
|
|
4
4
|
import torch
|
|
5
5
|
from lightning.pytorch.utilities.rank_zero import rank_zero_only
|
|
6
|
+
from typing_extensions import deprecated
|
|
6
7
|
|
|
7
8
|
from replay.metrics.torch_metrics_builder import TorchMetricsBuilder, metrics_to_df
|
|
8
9
|
from replay.models.nn.sequential.postprocessors import BasePostProcessor
|
|
@@ -18,6 +19,7 @@ CallbackMetricName = Literal[
|
|
|
18
19
|
]
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
@deprecated("`ValidationBatch` class is deprecated.", stacklevel=2)
|
|
21
23
|
class ValidationBatch(Protocol):
|
|
22
24
|
"""
|
|
23
25
|
Validation callback batch
|
|
@@ -28,12 +30,19 @@ class ValidationBatch(Protocol):
|
|
|
28
30
|
train: torch.LongTensor
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
@deprecated(
|
|
34
|
+
"`ValidationMetricsCallback` class is deprecated. "
|
|
35
|
+
"Use `replay.nn.lightning.callback.ComputeMetricsCallback` instead."
|
|
36
|
+
)
|
|
31
37
|
class ValidationMetricsCallback(lightning.Callback):
|
|
32
38
|
"""
|
|
33
39
|
Callback for validation and testing stages.
|
|
34
40
|
|
|
35
41
|
If multiple validation/testing dataloaders are used,
|
|
36
42
|
the suffix of the metric name will contain the serial number of the dataloader.
|
|
43
|
+
|
|
44
|
+
For the callback to work correctly, the batch must contain the `query_id` and `ground_truth` keys.
|
|
45
|
+
If you want to calculate the coverage or novelty metrics then the batch must additionally contain the `train` key.
|
|
37
46
|
"""
|
|
38
47
|
|
|
39
48
|
def __init__(
|
|
@@ -95,7 +104,7 @@ class ValidationMetricsCallback(lightning.Callback):
|
|
|
95
104
|
trainer: lightning.Trainer,
|
|
96
105
|
pl_module: lightning.LightningModule,
|
|
97
106
|
outputs: torch.Tensor,
|
|
98
|
-
batch: ValidationBatch,
|
|
107
|
+
batch: Union[ValidationBatch, dict],
|
|
99
108
|
batch_idx: int,
|
|
100
109
|
dataloader_idx: int = 0,
|
|
101
110
|
) -> None:
|
|
@@ -106,7 +115,7 @@ class ValidationMetricsCallback(lightning.Callback):
|
|
|
106
115
|
trainer: lightning.Trainer,
|
|
107
116
|
pl_module: lightning.LightningModule,
|
|
108
117
|
outputs: torch.Tensor,
|
|
109
|
-
batch: ValidationBatch,
|
|
118
|
+
batch: Union[ValidationBatch, dict],
|
|
110
119
|
batch_idx: int,
|
|
111
120
|
dataloader_idx: int = 0,
|
|
112
121
|
) -> None: # pragma: no cover
|
|
@@ -117,13 +126,21 @@ class ValidationMetricsCallback(lightning.Callback):
|
|
|
117
126
|
trainer: lightning.Trainer, # noqa: ARG002
|
|
118
127
|
pl_module: lightning.LightningModule,
|
|
119
128
|
outputs: torch.Tensor,
|
|
120
|
-
batch: ValidationBatch,
|
|
129
|
+
batch: Union[ValidationBatch, dict],
|
|
121
130
|
batch_idx: int,
|
|
122
131
|
dataloader_idx: int,
|
|
123
132
|
) -> None:
|
|
124
|
-
_, seen_scores, seen_ground_truth = self._compute_pipeline(
|
|
133
|
+
_, seen_scores, seen_ground_truth = self._compute_pipeline(
|
|
134
|
+
batch["query_id"] if isinstance(batch, dict) else batch.query_id,
|
|
135
|
+
outputs,
|
|
136
|
+
batch["ground_truth"] if isinstance(batch, dict) else batch.ground_truth,
|
|
137
|
+
)
|
|
125
138
|
sampled_items = torch.topk(seen_scores, k=self._metrics_builders[dataloader_idx].max_k, dim=1).indices
|
|
126
|
-
self._metrics_builders[dataloader_idx].add_prediction(
|
|
139
|
+
self._metrics_builders[dataloader_idx].add_prediction(
|
|
140
|
+
sampled_items,
|
|
141
|
+
seen_ground_truth,
|
|
142
|
+
batch.get("train") if isinstance(batch, dict) else batch.train,
|
|
143
|
+
)
|
|
127
144
|
|
|
128
145
|
if batch_idx + 1 == self._dataloaders_size[dataloader_idx]:
|
|
129
146
|
pl_module.log_dict(
|
|
@@ -101,17 +101,24 @@ class BaseCompiledModel:
|
|
|
101
101
|
)
|
|
102
102
|
raise ValueError(msg)
|
|
103
103
|
|
|
104
|
-
def
|
|
104
|
+
def _validate_predict_input(
|
|
105
|
+
self,
|
|
106
|
+
batch: Any,
|
|
107
|
+
candidates_to_score: Optional[torch.LongTensor] = None,
|
|
108
|
+
padding_mask_key_name: str = "padding_mask",
|
|
109
|
+
) -> None:
|
|
105
110
|
if self._num_candidates_to_score is None and candidates_to_score is not None:
|
|
106
111
|
msg = (
|
|
107
112
|
"If ``num_candidates_to_score`` is None, "
|
|
108
113
|
"it is impossible to infer the model with passed ``candidates_to_score``."
|
|
109
114
|
)
|
|
110
115
|
raise ValueError(msg)
|
|
111
|
-
|
|
112
|
-
|
|
116
|
+
input_batch_size = (
|
|
117
|
+
batch[padding_mask_key_name].shape[0] if isinstance(batch, dict) else batch.padding_mask.shape[0]
|
|
118
|
+
)
|
|
119
|
+
if self._batch_size != -1 and input_batch_size != self._batch_size:
|
|
113
120
|
msg = (
|
|
114
|
-
f"The batch is smaller
|
|
121
|
+
f"The batch is smaller than defined batch_size={self._batch_size}. "
|
|
115
122
|
"It is impossible to infer the model with dynamic batch size in ``mode`` = ``batch``. "
|
|
116
123
|
"Use ``mode`` = ``dynamic_batch_size``."
|
|
117
124
|
)
|
|
@@ -215,6 +222,7 @@ class BaseCompiledModel:
|
|
|
215
222
|
input_names=model_input_names,
|
|
216
223
|
output_names=["scores"],
|
|
217
224
|
dynamic_axes=model_dynamic_axes_in_input,
|
|
225
|
+
dynamo=False,
|
|
218
226
|
)
|
|
219
227
|
del lightning_model
|
|
220
228
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import pathlib
|
|
2
|
+
import warnings
|
|
2
3
|
from typing import Optional, Union, get_args
|
|
3
4
|
|
|
4
5
|
import openvino as ov
|
|
@@ -39,7 +40,7 @@ class Bert4RecCompiled(BaseCompiledModel):
|
|
|
39
40
|
|
|
40
41
|
def predict(
|
|
41
42
|
self,
|
|
42
|
-
batch: Bert4RecPredictionBatch,
|
|
43
|
+
batch: Union[Bert4RecPredictionBatch, dict],
|
|
43
44
|
candidates_to_score: Optional[torch.LongTensor] = None,
|
|
44
45
|
) -> torch.Tensor:
|
|
45
46
|
"""
|
|
@@ -51,13 +52,22 @@ class Bert4RecCompiled(BaseCompiledModel):
|
|
|
51
52
|
|
|
52
53
|
:return: Tensor with scores.
|
|
53
54
|
"""
|
|
54
|
-
self.
|
|
55
|
+
self._validate_predict_input(batch, candidates_to_score, "pad_mask")
|
|
56
|
+
|
|
57
|
+
if isinstance(batch, Bert4RecPredictionBatch):
|
|
58
|
+
warnings.warn(
|
|
59
|
+
"`Bert4RecPredictionBatch` class will be removed in future versions. "
|
|
60
|
+
"Instead, you should use simple dictionary",
|
|
61
|
+
DeprecationWarning,
|
|
62
|
+
stacklevel=2,
|
|
63
|
+
)
|
|
64
|
+
batch = batch.convert_to_dict()
|
|
55
65
|
|
|
56
66
|
batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
|
|
57
67
|
model_inputs = {
|
|
58
|
-
self._inputs_names[0]: batch
|
|
59
|
-
self._inputs_names[1]: batch
|
|
60
|
-
self._inputs_names[2]: batch
|
|
68
|
+
self._inputs_names[0]: batch["inputs"][self._inputs_names[0]],
|
|
69
|
+
self._inputs_names[1]: batch["pad_mask"],
|
|
70
|
+
self._inputs_names[2]: batch["token_mask"],
|
|
61
71
|
}
|
|
62
72
|
if self._num_candidates_to_score is not None:
|
|
63
73
|
self._validate_candidates_to_score(candidates_to_score)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import pathlib
|
|
2
|
+
import warnings
|
|
2
3
|
from typing import Optional, Union, get_args
|
|
3
4
|
|
|
4
5
|
import openvino as ov
|
|
@@ -39,7 +40,7 @@ class SasRecCompiled(BaseCompiledModel):
|
|
|
39
40
|
|
|
40
41
|
def predict(
|
|
41
42
|
self,
|
|
42
|
-
batch: SasRecPredictionBatch,
|
|
43
|
+
batch: Union[SasRecPredictionBatch, dict],
|
|
43
44
|
candidates_to_score: Optional[torch.LongTensor] = None,
|
|
44
45
|
) -> torch.Tensor:
|
|
45
46
|
"""
|
|
@@ -51,12 +52,21 @@ class SasRecCompiled(BaseCompiledModel):
|
|
|
51
52
|
|
|
52
53
|
:return: Tensor with scores.
|
|
53
54
|
"""
|
|
54
|
-
self.
|
|
55
|
+
self._validate_predict_input(batch, candidates_to_score)
|
|
56
|
+
|
|
57
|
+
if isinstance(batch, SasRecPredictionBatch):
|
|
58
|
+
warnings.warn(
|
|
59
|
+
"`SasRecPredictionBatch` class will be removed in future versions. "
|
|
60
|
+
"Instead, you should use simple dictionary",
|
|
61
|
+
DeprecationWarning,
|
|
62
|
+
stacklevel=2,
|
|
63
|
+
)
|
|
64
|
+
batch = batch.convert_to_dict()
|
|
55
65
|
|
|
56
66
|
batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
|
|
57
67
|
model_inputs = {
|
|
58
|
-
self._inputs_names[0]: batch
|
|
59
|
-
self._inputs_names[1]: batch
|
|
68
|
+
self._inputs_names[0]: batch["feature_tensor"][self._inputs_names[0]],
|
|
69
|
+
self._inputs_names[1]: batch["padding_mask"],
|
|
60
70
|
}
|
|
61
71
|
if self._num_candidates_to_score is not None:
|
|
62
72
|
self._validate_candidates_to_score(candidates_to_score)
|
|
@@ -77,15 +87,14 @@ class SasRecCompiled(BaseCompiledModel):
|
|
|
77
87
|
Model compilation.
|
|
78
88
|
|
|
79
89
|
:param model: Path to lightning SasRec model saved in .ckpt format or the SasRec object itself.
|
|
80
|
-
:param mode: Inference mode, defines shape of inputs
|
|
81
|
-
Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
|
|
90
|
+
:param mode: Inference mode, defines shape of inputs.\n
|
|
82
91
|
``one_query`` - sets input shape to [1, max_seq_len]\n
|
|
83
92
|
``batch`` - sets input shape to [batch_size, max_seq_len]\n
|
|
84
93
|
``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
|
|
85
94
|
Default: ``one_query``.
|
|
86
95
|
:param batch_size: Batch size, required for ``batch`` mode.
|
|
87
96
|
Default: ``None``.
|
|
88
|
-
:param num_candidates_to_score: Number of item ids to calculate scores
|
|
97
|
+
:param num_candidates_to_score: Number of item ids to calculate scores.\n
|
|
89
98
|
Could be one of [``None``, ``-1``, ``N``].\n
|
|
90
99
|
``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
|
|
91
100
|
``N`` - sets candidates_to_score shape to [1, N]\n
|
|
@@ -1,8 +1,13 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
+
from typing_extensions import deprecated
|
|
4
5
|
|
|
5
6
|
|
|
7
|
+
@deprecated(
|
|
8
|
+
"`BasePostProcessor` class is deprecated. Use `replay.nn.lightning.postprocessor.PostprocessorBase` instead.",
|
|
9
|
+
stacklevel=2,
|
|
10
|
+
)
|
|
6
11
|
class BasePostProcessor(abc.ABC): # pragma: no cover
|
|
7
12
|
"""
|
|
8
13
|
Abstract base class for post processor
|
|
@@ -3,12 +3,14 @@ from typing import Optional, Union, cast
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import pandas as pd
|
|
5
5
|
import torch
|
|
6
|
+
from typing_extensions import deprecated
|
|
6
7
|
|
|
7
8
|
from replay.data.nn import SequentialDataset
|
|
8
9
|
|
|
9
10
|
from ._base import BasePostProcessor
|
|
10
11
|
|
|
11
12
|
|
|
13
|
+
@deprecated("`RemoveSeenItems` class is deprecated. Use `replay.nn.lightning.postprocessor.SeenItemsFilter` instead.")
|
|
12
14
|
class RemoveSeenItems(BasePostProcessor):
|
|
13
15
|
"""
|
|
14
16
|
Filters out the items that already have been seen in dataset.
|
|
@@ -16,6 +18,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
16
18
|
|
|
17
19
|
def __init__(self, sequential: SequentialDataset) -> None:
|
|
18
20
|
super().__init__()
|
|
21
|
+
|
|
19
22
|
self._sequential = sequential
|
|
20
23
|
self._apply_candidates = False
|
|
21
24
|
self._candidates = None
|
|
@@ -107,6 +110,7 @@ class RemoveSeenItems(BasePostProcessor):
|
|
|
107
110
|
self._candidates = candidates
|
|
108
111
|
|
|
109
112
|
|
|
113
|
+
@deprecated("`SampleItems` class is deprecated.")
|
|
110
114
|
class SampleItems(BasePostProcessor):
|
|
111
115
|
"""
|
|
112
116
|
Generates negative samples to compute sampled metrics
|
|
@@ -1,7 +1,8 @@
|
|
|
1
|
-
from typing import NamedTuple, Optional
|
|
1
|
+
from typing import NamedTuple, Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch.utils.data import Dataset as TorchDataset
|
|
5
|
+
from typing_extensions import deprecated
|
|
5
6
|
|
|
6
7
|
from replay.data.nn import (
|
|
7
8
|
MutableTensorMap,
|
|
@@ -12,6 +13,10 @@ from replay.data.nn import (
|
|
|
12
13
|
)
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
@deprecated(
|
|
17
|
+
"`SasRecTrainingBatch` class is deprecated.",
|
|
18
|
+
stacklevel=2,
|
|
19
|
+
)
|
|
15
20
|
class SasRecTrainingBatch(NamedTuple):
|
|
16
21
|
"""
|
|
17
22
|
Batch of data for training.
|
|
@@ -24,10 +29,25 @@ class SasRecTrainingBatch(NamedTuple):
|
|
|
24
29
|
labels: torch.LongTensor
|
|
25
30
|
labels_padding_mask: torch.BoolTensor
|
|
26
31
|
|
|
32
|
+
def convert_to_dict(self) -> dict:
|
|
33
|
+
return {
|
|
34
|
+
"query_id": self.query_id,
|
|
35
|
+
"feature_tensor": self.features,
|
|
36
|
+
"padding_mask": self.padding_mask,
|
|
37
|
+
"positive_labels": self.labels,
|
|
38
|
+
"target_padding_mask": self.labels_padding_mask,
|
|
39
|
+
}
|
|
40
|
+
|
|
27
41
|
|
|
42
|
+
@deprecated("`SasRecTrainingDataset` class is deprecated. Use `replay.data.nn.ParquetModule` instead.")
|
|
28
43
|
class SasRecTrainingDataset(TorchDataset):
|
|
29
44
|
"""
|
|
30
|
-
Dataset that generates samples to train SasRec
|
|
45
|
+
Dataset that generates samples to train SasRec model.
|
|
46
|
+
|
|
47
|
+
As a result of the dataset iteration, a dictionary is formed.
|
|
48
|
+
The keys in the dictionary match the names of the arguments in the model's `forward` function.
|
|
49
|
+
There are also additional keys needed to calculate losses - 'positive_labels`, `target_padding_mask`.
|
|
50
|
+
The `query_id` key is required for possible debugging and calling additional lightning callbacks.
|
|
31
51
|
"""
|
|
32
52
|
|
|
33
53
|
def __init__(
|
|
@@ -81,7 +101,7 @@ class SasRecTrainingDataset(TorchDataset):
|
|
|
81
101
|
def __len__(self) -> int:
|
|
82
102
|
return len(self._inner)
|
|
83
103
|
|
|
84
|
-
def __getitem__(self, index: int) ->
|
|
104
|
+
def __getitem__(self, index: int) -> dict:
|
|
85
105
|
query_id, padding_mask, features = self._inner[index]
|
|
86
106
|
|
|
87
107
|
assert self._label_feature_name
|
|
@@ -97,15 +117,19 @@ class SasRecTrainingDataset(TorchDataset):
|
|
|
97
117
|
|
|
98
118
|
output_features_padding_mask = padding_mask[: -self._sequence_shift]
|
|
99
119
|
|
|
100
|
-
return
|
|
101
|
-
query_id
|
|
102
|
-
|
|
103
|
-
padding_mask
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
120
|
+
return {
|
|
121
|
+
"query_id": query_id,
|
|
122
|
+
"feature_tensor": output_features,
|
|
123
|
+
"padding_mask": output_features_padding_mask,
|
|
124
|
+
"positive_labels": labels,
|
|
125
|
+
"target_padding_mask": labels_padding_mask,
|
|
126
|
+
}
|
|
107
127
|
|
|
108
128
|
|
|
129
|
+
@deprecated(
|
|
130
|
+
"`SasRecPredictionBatch` class is deprecated.",
|
|
131
|
+
stacklevel=2,
|
|
132
|
+
)
|
|
109
133
|
class SasRecPredictionBatch(NamedTuple):
|
|
110
134
|
"""
|
|
111
135
|
Batch of data for model inference.
|
|
@@ -116,10 +140,22 @@ class SasRecPredictionBatch(NamedTuple):
|
|
|
116
140
|
padding_mask: torch.BoolTensor
|
|
117
141
|
features: TensorMap
|
|
118
142
|
|
|
143
|
+
def convert_to_dict(self) -> dict:
|
|
144
|
+
return {
|
|
145
|
+
"query_id": self.query_id,
|
|
146
|
+
"feature_tensor": self.features,
|
|
147
|
+
"padding_mask": self.padding_mask,
|
|
148
|
+
}
|
|
149
|
+
|
|
119
150
|
|
|
151
|
+
@deprecated("`SasRecPredictionDataset` class is deprecated. Use `replay.data.nn.ParquetModule` instead.")
|
|
120
152
|
class SasRecPredictionDataset(TorchDataset):
|
|
121
153
|
"""
|
|
122
|
-
Dataset that generates samples to infer SasRec
|
|
154
|
+
Dataset that generates samples to infer SasRec model
|
|
155
|
+
|
|
156
|
+
As a result of the dataset iteration, a dictionary is formed.
|
|
157
|
+
The keys in the dictionary match the names of the arguments in the model's `forward` function.
|
|
158
|
+
The `query_id` key is required for possible debugging and calling additional lightning callbacks.
|
|
123
159
|
"""
|
|
124
160
|
|
|
125
161
|
def __init__(
|
|
@@ -143,15 +179,19 @@ class SasRecPredictionDataset(TorchDataset):
|
|
|
143
179
|
def __len__(self) -> int:
|
|
144
180
|
return len(self._inner)
|
|
145
181
|
|
|
146
|
-
def __getitem__(self, index: int) ->
|
|
182
|
+
def __getitem__(self, index: int) -> dict:
|
|
147
183
|
query_id, padding_mask, features = self._inner[index]
|
|
148
|
-
return
|
|
149
|
-
query_id
|
|
150
|
-
padding_mask
|
|
151
|
-
features
|
|
152
|
-
|
|
184
|
+
return {
|
|
185
|
+
"query_id": query_id,
|
|
186
|
+
"padding_mask": padding_mask,
|
|
187
|
+
"feature_tensor": features,
|
|
188
|
+
}
|
|
153
189
|
|
|
154
190
|
|
|
191
|
+
@deprecated(
|
|
192
|
+
"`SasRecValidationBatch` class is deprecated.",
|
|
193
|
+
stacklevel=2,
|
|
194
|
+
)
|
|
155
195
|
class SasRecValidationBatch(NamedTuple):
|
|
156
196
|
"""
|
|
157
197
|
Batch of data for validation.
|
|
@@ -164,10 +204,25 @@ class SasRecValidationBatch(NamedTuple):
|
|
|
164
204
|
ground_truth: torch.LongTensor
|
|
165
205
|
train: torch.LongTensor
|
|
166
206
|
|
|
207
|
+
def convert_to_dict(self) -> dict:
|
|
208
|
+
return {
|
|
209
|
+
"query_id": self.query_id,
|
|
210
|
+
"feature_tensor": self.features,
|
|
211
|
+
"padding_mask": self.padding_mask,
|
|
212
|
+
"ground_truth": self.ground_truth,
|
|
213
|
+
"train": self.train,
|
|
214
|
+
}
|
|
167
215
|
|
|
216
|
+
|
|
217
|
+
@deprecated("`SasRecValidationDataset` class is deprecated. Use `replay.data.nn.ParquetModule` instead.")
|
|
168
218
|
class SasRecValidationDataset(TorchDataset):
|
|
169
219
|
"""
|
|
170
|
-
Dataset that generates samples to infer and validate SasRec
|
|
220
|
+
Dataset that generates samples to infer and validate SasRec model.
|
|
221
|
+
|
|
222
|
+
As a result of the dataset iteration, a dictionary is formed.
|
|
223
|
+
The keys in the dictionary match the names of the arguments in the model's `forward` function.
|
|
224
|
+
The `query_id` key is required for possible debugging and calling additional lightning callbacks.
|
|
225
|
+
Keys 'ground_truth` and `train` keys are required for metrics calculation on validation stage.
|
|
171
226
|
"""
|
|
172
227
|
|
|
173
228
|
def __init__(
|
|
@@ -202,12 +257,12 @@ class SasRecValidationDataset(TorchDataset):
|
|
|
202
257
|
def __len__(self) -> int:
|
|
203
258
|
return len(self._inner)
|
|
204
259
|
|
|
205
|
-
def __getitem__(self, index: int) ->
|
|
260
|
+
def __getitem__(self, index: int) -> dict:
|
|
206
261
|
query_id, padding_mask, features, ground_truth, train = self._inner[index]
|
|
207
|
-
return
|
|
208
|
-
query_id
|
|
209
|
-
padding_mask
|
|
210
|
-
features
|
|
211
|
-
ground_truth
|
|
212
|
-
train
|
|
213
|
-
|
|
262
|
+
return {
|
|
263
|
+
"query_id": query_id,
|
|
264
|
+
"padding_mask": padding_mask,
|
|
265
|
+
"feature_tensor": features,
|
|
266
|
+
"ground_truth": ground_truth,
|
|
267
|
+
"train": train,
|
|
268
|
+
}
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import math
|
|
2
|
+
import warnings
|
|
2
3
|
from typing import Any, Literal, Optional, Union, cast
|
|
3
4
|
|
|
4
5
|
import lightning
|
|
5
6
|
import torch
|
|
7
|
+
from typing_extensions import deprecated
|
|
6
8
|
|
|
7
9
|
from replay.data.nn import TensorMap, TensorSchema
|
|
8
10
|
from replay.models.nn.loss import ScalableCrossEntropyLoss, SCEParams
|
|
@@ -12,6 +14,11 @@ from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidatio
|
|
|
12
14
|
from .model import SasRecModel
|
|
13
15
|
|
|
14
16
|
|
|
17
|
+
@deprecated(
|
|
18
|
+
"`SasRec` class is deprecated. "
|
|
19
|
+
"Use `replay.nn.sequential.SasRec` "
|
|
20
|
+
"and `replay.nn.lightning.LightningModule` instead."
|
|
21
|
+
)
|
|
15
22
|
class SasRec(lightning.LightningModule):
|
|
16
23
|
"""
|
|
17
24
|
SASRec Lightning module.
|
|
@@ -54,9 +61,9 @@ class SasRec(lightning.LightningModule):
|
|
|
54
61
|
Default: ``False``.
|
|
55
62
|
:param time_span: Time span value.
|
|
56
63
|
Default: ``256``.
|
|
57
|
-
:param loss_type: Loss type.
|
|
64
|
+
:param loss_type: Loss type.
|
|
58
65
|
Default: ``CE``.
|
|
59
|
-
:param loss_sample_count
|
|
66
|
+
:param loss_sample_count: Sample count to calculate loss.
|
|
60
67
|
Suitable for ``"CE"`` and ``"BCE"`` loss functions.
|
|
61
68
|
Default: ``None``.
|
|
62
69
|
:param negative_sampling_strategy: Negative sampling strategy to calculate loss on sampled negatives.
|
|
@@ -74,6 +81,7 @@ class SasRec(lightning.LightningModule):
|
|
|
74
81
|
"""
|
|
75
82
|
super().__init__()
|
|
76
83
|
self.save_hyperparameters()
|
|
84
|
+
|
|
77
85
|
self._model = SasRecModel(
|
|
78
86
|
schema=tensor_schema,
|
|
79
87
|
num_blocks=block_count,
|
|
@@ -102,7 +110,7 @@ class SasRec(lightning.LightningModule):
|
|
|
102
110
|
self._vocab_size = item_count
|
|
103
111
|
self.candidates_to_score = None
|
|
104
112
|
|
|
105
|
-
def training_step(self, batch: SasRecTrainingBatch, batch_idx: int) -> torch.Tensor:
|
|
113
|
+
def training_step(self, batch: Union[SasRecTrainingBatch, dict], batch_idx: int) -> torch.Tensor:
|
|
106
114
|
"""
|
|
107
115
|
:param batch (SasRecTrainingBatch): Batch of training data.
|
|
108
116
|
:param batch_idx (int): Batch index.
|
|
@@ -117,7 +125,7 @@ class SasRec(lightning.LightningModule):
|
|
|
117
125
|
|
|
118
126
|
def predict_step(
|
|
119
127
|
self,
|
|
120
|
-
batch: SasRecPredictionBatch,
|
|
128
|
+
batch: Union[SasRecPredictionBatch, dict],
|
|
121
129
|
batch_idx: int, # noqa: ARG002
|
|
122
130
|
dataloader_idx: int = 0, # noqa: ARG002
|
|
123
131
|
) -> torch.Tensor:
|
|
@@ -128,12 +136,23 @@ class SasRec(lightning.LightningModule):
|
|
|
128
136
|
|
|
129
137
|
:returns: Calculated scores.
|
|
130
138
|
"""
|
|
139
|
+
if isinstance(batch, SasRecPredictionBatch):
|
|
140
|
+
warnings.warn(
|
|
141
|
+
"`SasRecPredictionBatch` class will be removed in future versions. "
|
|
142
|
+
"Instead, you should use simple dictionary",
|
|
143
|
+
DeprecationWarning,
|
|
144
|
+
stacklevel=2,
|
|
145
|
+
)
|
|
146
|
+
batch = batch.convert_to_dict()
|
|
131
147
|
batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
|
|
132
|
-
return self._model_predict(
|
|
148
|
+
return self._model_predict(
|
|
149
|
+
feature_tensors=batch["feature_tensor"],
|
|
150
|
+
padding_mask=batch["padding_mask"],
|
|
151
|
+
)
|
|
133
152
|
|
|
134
153
|
def predict(
|
|
135
154
|
self,
|
|
136
|
-
batch: SasRecPredictionBatch,
|
|
155
|
+
batch: Union[SasRecPredictionBatch, dict],
|
|
137
156
|
candidates_to_score: Optional[torch.LongTensor] = None,
|
|
138
157
|
) -> torch.Tensor:
|
|
139
158
|
"""
|
|
@@ -143,8 +162,20 @@ class SasRec(lightning.LightningModule):
|
|
|
143
162
|
|
|
144
163
|
:returns: Calculated scores.
|
|
145
164
|
"""
|
|
165
|
+
if isinstance(batch, SasRecPredictionBatch):
|
|
166
|
+
warnings.warn(
|
|
167
|
+
"`SasRecPredictionBatch` class will be removed in future versions. "
|
|
168
|
+
"Instead, you should use simple dictionary",
|
|
169
|
+
DeprecationWarning,
|
|
170
|
+
stacklevel=2,
|
|
171
|
+
)
|
|
172
|
+
batch = batch.convert_to_dict()
|
|
146
173
|
batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
|
|
147
|
-
return self._model_predict(
|
|
174
|
+
return self._model_predict(
|
|
175
|
+
feature_tensors=batch["feature_tensor"],
|
|
176
|
+
padding_mask=batch["padding_mask"],
|
|
177
|
+
candidates_to_score=candidates_to_score,
|
|
178
|
+
)
|
|
148
179
|
|
|
149
180
|
def forward(
|
|
150
181
|
self,
|
|
@@ -164,7 +195,7 @@ class SasRec(lightning.LightningModule):
|
|
|
164
195
|
|
|
165
196
|
def validation_step(
|
|
166
197
|
self,
|
|
167
|
-
batch: SasRecValidationBatch,
|
|
198
|
+
batch: Union[SasRecValidationBatch, dict],
|
|
168
199
|
batch_idx: int, # noqa: ARG002
|
|
169
200
|
dataloader_idx: int = 0, # noqa: ARG002
|
|
170
201
|
) -> torch.Tensor:
|
|
@@ -174,7 +205,19 @@ class SasRec(lightning.LightningModule):
|
|
|
174
205
|
|
|
175
206
|
:returns: Calculated scores.
|
|
176
207
|
"""
|
|
177
|
-
|
|
208
|
+
if isinstance(batch, SasRecValidationBatch):
|
|
209
|
+
warnings.warn(
|
|
210
|
+
"`SasRecValidationBatch` class will be removed in future versions. "
|
|
211
|
+
"Instead, you should use simple dictionary",
|
|
212
|
+
DeprecationWarning,
|
|
213
|
+
stacklevel=2,
|
|
214
|
+
)
|
|
215
|
+
batch = batch.convert_to_dict()
|
|
216
|
+
|
|
217
|
+
return self._model_predict(
|
|
218
|
+
feature_tensors=batch["feature_tensor"],
|
|
219
|
+
padding_mask=batch["padding_mask"],
|
|
220
|
+
)
|
|
178
221
|
|
|
179
222
|
def configure_optimizers(self) -> Any:
|
|
180
223
|
"""
|
|
@@ -197,10 +240,14 @@ class SasRec(lightning.LightningModule):
|
|
|
197
240
|
model: SasRecModel
|
|
198
241
|
model = cast(SasRecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
|
|
199
242
|
candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
|
|
200
|
-
scores = model.predict(
|
|
243
|
+
scores = model.predict(
|
|
244
|
+
feature_tensor=feature_tensors,
|
|
245
|
+
padding_mask=padding_mask,
|
|
246
|
+
candidates_to_score=candidates_to_score,
|
|
247
|
+
)
|
|
201
248
|
return scores
|
|
202
249
|
|
|
203
|
-
def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
|
|
250
|
+
def _compute_loss(self, batch: Union[SasRecTrainingBatch, dict]) -> torch.Tensor:
|
|
204
251
|
if self._loss_type == "BCE":
|
|
205
252
|
loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
|
|
206
253
|
elif self._loss_type == "CE":
|
|
@@ -211,11 +258,20 @@ class SasRec(lightning.LightningModule):
|
|
|
211
258
|
msg = f"Not supported loss type: {self._loss_type}"
|
|
212
259
|
raise ValueError(msg)
|
|
213
260
|
|
|
261
|
+
if isinstance(batch, SasRecTrainingBatch):
|
|
262
|
+
warnings.warn(
|
|
263
|
+
"`SasRecTrainingBatch` class will be removed in future versions. "
|
|
264
|
+
"Instead, you should use simple dictionary",
|
|
265
|
+
DeprecationWarning,
|
|
266
|
+
stacklevel=2,
|
|
267
|
+
)
|
|
268
|
+
batch = batch.convert_to_dict()
|
|
269
|
+
|
|
214
270
|
loss = loss_func(
|
|
215
|
-
batch
|
|
216
|
-
batch
|
|
217
|
-
batch
|
|
218
|
-
batch
|
|
271
|
+
batch["feature_tensor"],
|
|
272
|
+
batch["positive_labels"],
|
|
273
|
+
batch["padding_mask"],
|
|
274
|
+
batch["target_padding_mask"],
|
|
219
275
|
)
|
|
220
276
|
return loss
|
|
221
277
|
|
|
@@ -258,7 +314,7 @@ class SasRec(lightning.LightningModule):
|
|
|
258
314
|
padding_mask: torch.BoolTensor,
|
|
259
315
|
target_padding_mask: torch.BoolTensor,
|
|
260
316
|
) -> torch.Tensor:
|
|
261
|
-
|
|
317
|
+
positive_logits, negative_logits, *_ = self._get_sampled_logits(
|
|
262
318
|
feature_tensors, positive_labels, padding_mask, target_padding_mask
|
|
263
319
|
)
|
|
264
320
|
|
|
@@ -306,7 +362,7 @@ class SasRec(lightning.LightningModule):
|
|
|
306
362
|
target_padding_mask: torch.BoolTensor,
|
|
307
363
|
) -> torch.Tensor:
|
|
308
364
|
assert self._loss_sample_count is not None
|
|
309
|
-
|
|
365
|
+
positive_logits, negative_logits, positive_labels, negative_labels, vocab_size = self._get_sampled_logits(
|
|
310
366
|
feature_tensors, positive_labels, padding_mask, target_padding_mask
|
|
311
367
|
)
|
|
312
368
|
n_negative_samples = min(self._loss_sample_count, vocab_size)
|
|
@@ -566,19 +622,23 @@ class SasRec(lightning.LightningModule):
|
|
|
566
622
|
|
|
567
623
|
|
|
568
624
|
def _prepare_prediction_batch(
|
|
569
|
-
schema: TensorSchema,
|
|
570
|
-
|
|
571
|
-
|
|
625
|
+
schema: TensorSchema,
|
|
626
|
+
max_len: int,
|
|
627
|
+
batch: dict,
|
|
628
|
+
) -> dict:
|
|
629
|
+
seq_len = batch["padding_mask"].shape[1]
|
|
630
|
+
if seq_len > max_len:
|
|
572
631
|
msg = (
|
|
573
632
|
"The length of the submitted sequence "
|
|
574
633
|
"must not exceed the maximum length of the sequence. "
|
|
575
|
-
f"The length of the sequence is given {
|
|
634
|
+
f"The length of the sequence is given {seq_len}, "
|
|
576
635
|
f"while the maximum length is {max_len}"
|
|
577
636
|
)
|
|
578
637
|
raise ValueError(msg)
|
|
579
638
|
|
|
580
|
-
if
|
|
581
|
-
|
|
639
|
+
if seq_len < max_len:
|
|
640
|
+
padding_mask = batch["padding_mask"]
|
|
641
|
+
features = batch["feature_tensor"].copy()
|
|
582
642
|
sequence_item_count = padding_mask.shape[1]
|
|
583
643
|
for feature_name, feature_tensor in features.items():
|
|
584
644
|
if schema[feature_name].is_cat:
|
|
@@ -592,5 +652,7 @@ def _prepare_prediction_batch(
|
|
|
592
652
|
value=0,
|
|
593
653
|
).unsqueeze(-1)
|
|
594
654
|
padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
|
|
595
|
-
batch =
|
|
655
|
+
batch["padding_mask"] = padding_mask
|
|
656
|
+
batch["feature_tensor"] = features
|
|
657
|
+
|
|
596
658
|
return batch
|