replay-rec 0.18.0-py3-none-any.whl → 0.18.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +27 -1
- replay/data/dataset_utils/dataset_label_encoder.py +6 -3
- replay/data/nn/schema.py +37 -16
- replay/data/nn/sequence_tokenizer.py +313 -165
- replay/data/nn/torch_sequential_dataset.py +17 -8
- replay/data/nn/utils.py +14 -7
- replay/data/schema.py +10 -6
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +1 -0
- replay/models/base_rec.py +18 -21
- replay/models/lin_ucb.py +407 -0
- replay/models/nn/sequential/bert4rec/dataset.py +17 -4
- replay/models/nn/sequential/bert4rec/lightning.py +121 -54
- replay/models/nn/sequential/bert4rec/model.py +21 -0
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +5 -1
- replay/models/nn/sequential/compiled/__init__.py +5 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +261 -0
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +152 -0
- replay/models/nn/sequential/compiled/sasrec_compiled.py +145 -0
- replay/models/nn/sequential/postprocessors/postprocessors.py +27 -1
- replay/models/nn/sequential/sasrec/dataset.py +17 -1
- replay/models/nn/sequential/sasrec/lightning.py +126 -50
- replay/models/nn/sequential/sasrec/model.py +3 -4
- replay/preprocessing/__init__.py +7 -1
- replay/preprocessing/discretizer.py +719 -0
- replay/preprocessing/label_encoder.py +384 -52
- replay/splitters/cold_user_random_splitter.py +1 -1
- replay/utils/__init__.py +1 -0
- replay/utils/common.py +7 -8
- replay/utils/session_handler.py +3 -4
- replay/utils/spark_utils.py +15 -1
- replay/utils/types.py +8 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/METADATA +73 -60
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/RECORD +37 -31
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/LICENSE +0 -0
- {replay_rec-0.18.0.dist-info → replay_rec-0.18.1.dist-info}/WHEEL +0 -0
replay/models/nn/sequential/compiled/bert4rec_compiled.py (new file)

```diff
@@ -0,0 +1,152 @@
+import pathlib
+from typing import Optional, Union, get_args
+
+import openvino as ov
+import torch
+
+from replay.data.nn import TensorSchema
+from replay.models.nn.sequential.bert4rec import (
+    Bert4Rec,
+    Bert4RecPredictionBatch,
+)
+from replay.models.nn.sequential.bert4rec.lightning import _prepare_prediction_batch
+from replay.models.nn.sequential.compiled.base_compiled_model import (
+    BaseCompiledModel,
+    OptimizedModeType,
+)
+
+
+class Bert4RecCompiled(BaseCompiledModel):
+    """
+    Bert4Rec CPU-optimized model for inference via OpenVINO.
+    It is recommended to compile model with ``compile`` method and pass ``Bert4Rec`` checkpoint
+    or the model object itself into it.
+    It is also possible to compile model by yourself and pass it to the ``__init__`` with ``TensorSchema``.
+
+    **Note** that compilation requires disk write (and maybe delete) permission.
+    """
+
+    def __init__(
+        self,
+        compiled_model: ov.CompiledModel,
+        schema: TensorSchema,
+    ) -> None:
+        """
+        :param compiled_model: Compiled model.
+        :param schema: Tensor schema of Bert4Rec model.
+        """
+        super().__init__(compiled_model, schema)
+
+    def predict(
+        self,
+        batch: Bert4RecPredictionBatch,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        Inference on one batch.
+
+        :param batch: Prediction input.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :return: Tensor with scores.
+        """
+        self._valilade_predict_input(batch, candidates_to_score)
+
+        batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
+        model_inputs = {
+            self._inputs_names[0]: batch.features[self._inputs_names[0]],
+            self._inputs_names[1]: batch.padding_mask,
+            self._inputs_names[2]: batch.tokens_mask,
+        }
+        if self._num_candidates_to_score is not None:
+            self._validate_candidates_to_score(candidates_to_score)
+            model_inputs[self._inputs_names[3]] = candidates_to_score
+        return torch.from_numpy(self._model(model_inputs)[self._output_name])
+
+    @classmethod
+    def compile(
+        cls,
+        model: Union[Bert4Rec, str, pathlib.Path],
+        mode: OptimizedModeType = "one_query",
+        batch_size: Optional[int] = None,
+        num_candidates_to_score: Optional[int] = None,
+        num_threads: Optional[int] = None,
+        onnx_path: Optional[str] = None,
+    ) -> "Bert4RecCompiled":
+        """
+        Model compilation.
+
+        :param model: Path to lightning Bert4Rec model saved in .ckpt format or the Bert4Rec object itself.
+        :param mode: Inference mode, defines shape of inputs.
+            Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
+            ``one_query`` - sets input shape to [1, max_seq_len]\n
+            ``batch`` - sets input shape to [batch_size, max_seq_len]\n
+            ``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
+            Default: ``one_query``.
+        :param batch_size: Batch size, required for ``batch`` mode.
+            Default: ``None``.
+        :param num_candidates_to_score: Number of item ids to calculate scores.
+            Could be one of [``None``, ``-1``, ``N``].\n
+            ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
+            ``N`` - sets candidates_to_score shape to [1, N]\n
+            ``None`` - disables candidates_to_score usage\n
+            Default: ``None``.
+        :param num_threads: Number of CPU threads to use.
+            Must be a natural number or ``None``.
+            If ``None``, then compiler will set this parameter automatically.
+            Default: ``None``.
+        :param onnx_path: Save ONNX model to path, if defined.
+            Default: ``None``.
+        """
+        if mode not in get_args(OptimizedModeType):
+            msg = f"Parameter ``mode`` could be one of {get_args(OptimizedModeType)}."
+            raise ValueError(msg)
+        num_candidates_to_score = Bert4RecCompiled._validate_num_candidates_to_score(num_candidates_to_score)
+        if isinstance(model, Bert4Rec):
+            lightning_model = model.cpu()
+        elif isinstance(model, (str, pathlib.Path)):
+            lightning_model = Bert4Rec.load_from_checkpoint(model, map_location=torch.device("cpu"))
+
+        schema = lightning_model._schema
+        item_seq_name = schema.item_id_feature_name
+        max_seq_len = lightning_model._model.max_len
+
+        batch_size, num_candidates_to_score = Bert4RecCompiled._get_input_params(
+            mode, batch_size, num_candidates_to_score
+        )
+
+        item_sequence = torch.zeros((1, max_seq_len)).long()
+        padding_mask = torch.zeros((1, max_seq_len)).bool()
+        tokens_mask = torch.zeros((1, max_seq_len)).bool()
+
+        model_input_names = [item_seq_name, "padding_mask", "tokens_mask"]
+        model_dynamic_axes_in_input = {
+            item_seq_name: {0: "batch_size", 1: "max_len"},
+            "padding_mask": {0: "batch_size", 1: "max_len"},
+            "tokens_mask": {0: "batch_size", 1: "max_len"},
+        }
+        if num_candidates_to_score:
+            candidates_to_score = torch.zeros((1,)).long()
+            model_input_names += ["candidates_to_score"]
+            model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
+            model_input_sample = ({item_seq_name: item_sequence}, padding_mask, tokens_mask, candidates_to_score)
+        else:
+            model_input_sample = ({item_seq_name: item_sequence}, padding_mask, tokens_mask)
+
+        # Need to disable "Better Transformer" optimizations that interfere with the compilation process
+        if hasattr(torch.backends, "mha"):
+            torch.backends.mha.set_fastpath_enabled(value=False)
+
+        compiled_model = Bert4RecCompiled._run_model_compilation(
+            lightning_model,
+            model_input_sample,
+            model_input_names,
+            model_dynamic_axes_in_input,
+            batch_size,
+            num_candidates_to_score,
+            num_threads,
+            onnx_path,
+        )
+
+        return cls(compiled_model, schema)
```
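For orientation, here is how the new compiled wrapper is meant to be used. This is a minimal sketch, not code from the package: the checkpoint path is a placeholder, the `from replay.models.nn.sequential.compiled import ...` re-export is assumed from the new `compiled/__init__.py`, and `batch` stands for a `Bert4RecPredictionBatch` produced by the library's prediction dataloader.

```python
import torch

# Assumed re-export via the new replay/models/nn/sequential/compiled/__init__.py
from replay.models.nn.sequential.compiled import Bert4RecCompiled

# Fixed single-query input shape [1, max_seq_len]; scores the full catalog.
one_query_model = Bert4RecCompiled.compile(model="bert4rec.ckpt")  # hypothetical path

# Fixed batch shape [32, max_seq_len] with a dynamically sized candidate list.
batch_model = Bert4RecCompiled.compile(
    model="bert4rec.ckpt",
    mode="batch",
    batch_size=32,
    num_candidates_to_score=-1,  # candidates_to_score input shape becomes [1, ?]
)

# `batch` is a Bert4RecPredictionBatch from a prediction dataloader (placeholder).
candidates = torch.tensor([3, 14, 159], dtype=torch.long)
scores = batch_model.predict(batch, candidates_to_score=candidates)
```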
replay/models/nn/sequential/compiled/sasrec_compiled.py (new file)

```diff
@@ -0,0 +1,145 @@
+import pathlib
+from typing import Optional, Union, get_args
+
+import openvino as ov
+import torch
+
+from replay.data.nn import TensorSchema
+from replay.models.nn.sequential.compiled.base_compiled_model import (
+    BaseCompiledModel,
+    OptimizedModeType,
+)
+from replay.models.nn.sequential.sasrec import (
+    SasRec,
+    SasRecPredictionBatch,
+)
+from replay.models.nn.sequential.sasrec.lightning import _prepare_prediction_batch
+
+
+class SasRecCompiled(BaseCompiledModel):
+    """
+    SasRec CPU-optimized model for inference via OpenVINO.
+    It is recommended to compile model with ``compile`` method and pass ``SasRec`` checkpoint
+    or the model object itself into it.
+    It is also possible to compile model by yourself and pass it to the ``__init__`` with ``TensorSchema``.
+
+    **Note** that compilation requires disk write (and maybe delete) permission.
+    """
+
+    def __init__(
+        self,
+        compiled_model: ov.CompiledModel,
+        schema: TensorSchema,
+    ) -> None:
+        """
+        :param compiled_model: Compiled model.
+        :param schema: Tensor schema of SasRec model.
+        """
+        super().__init__(compiled_model, schema)
+
+    def predict(
+        self,
+        batch: SasRecPredictionBatch,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        """
+        Inference on one batch.
+
+        :param batch: Prediction input.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :return: Tensor with scores.
+        """
+        self._valilade_predict_input(batch, candidates_to_score)
+
+        batch = _prepare_prediction_batch(self._schema, self._max_seq_len, batch)
+        model_inputs = {
+            self._inputs_names[0]: batch.features[self._inputs_names[0]],
+            self._inputs_names[1]: batch.padding_mask,
+        }
+        if self._num_candidates_to_score is not None:
+            self._validate_candidates_to_score(candidates_to_score)
+            model_inputs[self._inputs_names[2]] = candidates_to_score
+        return torch.from_numpy(self._model(model_inputs)[self._output_name])
+
+    @classmethod
+    def compile(
+        cls,
+        model: Union[SasRec, str, pathlib.Path],
+        mode: OptimizedModeType = "one_query",
+        batch_size: Optional[int] = None,
+        num_candidates_to_score: Optional[int] = None,
+        num_threads: Optional[int] = None,
+        onnx_path: Optional[str] = None,
+    ) -> "SasRecCompiled":
+        """
+        Model compilation.
+
+        :param model: Path to lightning SasRec model saved in .ckpt format or the SasRec object itself.
+        :param mode: Inference mode, defines shape of inputs.
+            Could be one of [``one_query``, ``batch``, ``dynamic_batch_size``].\n
+            ``one_query`` - sets input shape to [1, max_seq_len]\n
+            ``batch`` - sets input shape to [batch_size, max_seq_len]\n
+            ``dynamic_batch_size`` - sets batch_size to dynamic range [?, max_seq_len]\n
+            Default: ``one_query``.
+        :param batch_size: Batch size, required for ``batch`` mode.
+            Default: ``None``.
+        :param num_candidates_to_score: Number of item ids to calculate scores.
+            Could be one of [``None``, ``-1``, ``N``].\n
+            ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]\n
+            ``N`` - sets candidates_to_score shape to [1, N]\n
+            ``None`` - disable candidates_to_score usage\n
+            Default: ``None``.
+        :param num_threads: Number of CPU threads to use.
+            Must be a natural number or ``None``.
+            If ``None``, then compiler will set this parameter automatically.
+            Default: ``None``.
+        :param onnx_path: Save ONNX model to path, if defined.
+            Default: ``None``.
+        """
+        if mode not in get_args(OptimizedModeType):
+            msg = f"Parameter ``mode`` could be one of {get_args(OptimizedModeType)}."
+            raise ValueError(msg)
+        num_candidates_to_score = SasRecCompiled._validate_num_candidates_to_score(num_candidates_to_score)
+        if isinstance(model, SasRec):
+            lightning_model = model.cpu()
+        elif isinstance(model, (str, pathlib.Path)):
+            lightning_model = SasRec.load_from_checkpoint(model, map_location=torch.device("cpu"))
+
+        schema = lightning_model._schema
+        item_seq_name = schema.item_id_feature_name
+        max_seq_len = lightning_model._model.max_len
+
+        batch_size, num_candidates_to_score = SasRecCompiled._get_input_params(
+            mode, batch_size, num_candidates_to_score
+        )
+
+        item_sequence = torch.zeros((1, max_seq_len)).long()
+        padding_mask = torch.zeros((1, max_seq_len)).bool()
+
+        model_input_names = [item_seq_name, "padding_mask"]
+        model_dynamic_axes_in_input = {
+            item_seq_name: {0: "batch_size", 1: "max_len"},
+            "padding_mask": {0: "batch_size", 1: "max_len"},
+        }
+        if num_candidates_to_score:
+            candidates_to_score = torch.zeros((1,)).long()
+            model_input_names += ["candidates_to_score"]
+            model_dynamic_axes_in_input["candidates_to_score"] = {0: "num_candidates_to_score"}
+            model_input_sample = ({item_seq_name: item_sequence}, padding_mask, candidates_to_score)
+        else:
+            model_input_sample = ({item_seq_name: item_sequence}, padding_mask)
+
+        compiled_model = SasRecCompiled._run_model_compilation(
+            lightning_model,
+            model_input_sample,
+            model_input_names,
+            model_dynamic_axes_in_input,
+            batch_size,
+            num_candidates_to_score,
+            num_threads,
+            onnx_path,
+        )
+
+        return cls(compiled_model, schema)
```
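The SasRec wrapper mirrors the Bert4Rec one, minus the `tokens_mask` input. A hedged sketch of the remaining compile options; `trained_sasrec` is a fitted `SasRec` module, `batch` is a `SasRecPredictionBatch`, and the re-export is again assumed:

```python
from replay.models.nn.sequential.compiled import SasRecCompiled  # assumed re-export

# Dynamic batch dimension [?, max_seq_len], pinned to four CPU threads,
# keeping the intermediate ONNX artifact on disk.
compiled = SasRecCompiled.compile(
    model=trained_sasrec,        # or a path to a .ckpt checkpoint (placeholder)
    mode="dynamic_batch_size",
    num_threads=4,
    onnx_path="sasrec.onnx",     # hypothetical output path
)

# Full-catalog scores, one row per query in the batch (placeholder batch).
scores = compiled.predict(batch)
```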
replay/models/nn/sequential/postprocessors/postprocessors.py

```diff
@@ -1,4 +1,4 @@
-from typing import List, Optional, Set, Tuple, cast
+from typing import List, Optional, Set, Tuple, Union, cast
 
 import numpy as np
 import pandas as pd
@@ -17,6 +17,8 @@ class RemoveSeenItems(BasePostProcessor):
     def __init__(self, sequential: SequentialDataset) -> None:
         super().__init__()
         self._sequential = sequential
+        self._apply_candidates = False
+        self._candidates = None
 
     def on_validation(
         self, query_ids: torch.LongTensor, scores: torch.Tensor, ground_truth: torch.LongTensor
@@ -30,6 +32,7 @@ class RemoveSeenItems(BasePostProcessor):
 
         :returns: modified query ids and scores and ground truth dataset
         """
+        self._apply_candidates = False
         modified_scores = self._compute_scores(query_ids, scores)
         return query_ids, modified_scores, ground_truth
 
@@ -42,6 +45,7 @@ class RemoveSeenItems(BasePostProcessor):
 
         :returns: modified query ids and scores
         """
+        self._apply_candidates = True
         modified_scores = self._compute_scores(query_ids, scores)
         return query_ids, modified_scores
 
@@ -56,6 +60,13 @@ class RemoveSeenItems(BasePostProcessor):
         value: float,
     ) -> torch.Tensor:
         flat_item_ids_on_device = flat_item_ids.to(scores.device)
+
+        if self._apply_candidates and self._candidates is not None:
+            item_count = self._sequential.schema.item_id_features.item().cardinality
+            assert item_count
+            _scores = torch.full((scores.shape[0], item_count), -float("inf")).to(scores.device)
+            _scores[:, self._candidates] = torch.reshape(scores, _scores[:, self._candidates].shape)
+            scores = _scores
         if scores.is_contiguous():
             scores.view(-1)[flat_item_ids_on_device] = value
         else:
@@ -80,6 +91,21 @@ class RemoveSeenItems(BasePostProcessor):
         flat_seen_item_ids_np = np.concatenate(item_id_sequences)
         return torch.LongTensor(flat_seen_item_ids_np)
 
+    @property
+    def candidates(self) -> Union[torch.LongTensor, None]:
+        """
+        Returns tensor of item ids to calculate scores.
+        """
+        return self._candidates
+
+    @candidates.setter
+    def candidates(self, candidates: Optional[torch.LongTensor] = None) -> None:
+        """
+        Sets tensor of item ids to calculate scores.
+        :param candidates: Tensor of item ids to calculate scores.
+        """
+        self._candidates = candidates
+
 
 class SampleItems(BasePostProcessor):
     """
```
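The block added to RemoveSeenItems' score-masking helper scatters per-candidate scores into a full catalog-sized tensor (filled with `-inf`) before seen items are masked, so columns keep their global item-id meaning even when the model scored only a candidate subset. A self-contained sketch of that scatter with toy numbers:

```python
import torch

item_count = 10                           # catalog size (schema cardinality)
candidates = torch.tensor([2, 5, 7])      # item ids the model actually scored
scores = torch.tensor([[0.3, 0.9, 0.1]])  # shape [batch_size, num_candidates]

# Expand candidate scores to full item space, -inf everywhere else,
# so seen items can then be masked by their global flat indices.
_scores = torch.full((scores.shape[0], item_count), -float("inf"))
_scores[:, candidates] = torch.reshape(scores, _scores[:, candidates].shape)
# _scores -> [[-inf, -inf, 0.3, -inf, -inf, 0.9, -inf, 0.1, -inf, -inf]]
```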
replay/models/nn/sequential/sasrec/dataset.py

```diff
@@ -10,6 +10,7 @@ from replay.data.nn import (
     TorchSequentialDataset,
     TorchSequentialValidationDataset,
 )
+from replay.utils.model_handler import deprecation_warning
 
 
 class SasRecTrainingBatch(NamedTuple):
@@ -30,6 +31,10 @@ class SasRecTrainingDataset(TorchDataset):
     Dataset that generates samples to train SasRec-like model
     """
 
+    @deprecation_warning(
+        "`padding_value` parameter will be removed in future versions. "
+        "Instead, you should specify `padding_value` for each column in TensorSchema"
+    )
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -90,7 +95,10 @@ class SasRecTrainingDataset(TorchDataset):
 
         output_features: MutableTensorMap = {}
         for feature_name in self._schema:
-            output_features[feature_name] = features[feature_name][: -self._sequence_shift]
+            feature = features[feature_name]
+            if self._schema[feature_name].is_seq:
+                feature = feature[: -self._sequence_shift]
+            output_features[feature_name] = feature
 
         output_features_padding_mask = padding_mask[: -self._sequence_shift]
 
@@ -119,6 +127,10 @@ class SasRecPredictionDataset(TorchDataset):
     Dataset that generates samples to infer SasRec-like model
     """
 
+    @deprecation_warning(
+        "`padding_value` parameter will be removed in future versions. "
+        "Instead, you should specify `padding_value` for each column in TensorSchema"
+    )
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -167,6 +179,10 @@ class SasRecValidationDataset(TorchDataset):
     Dataset that generates samples to infer and validate SasRec-like model
     """
 
+    @deprecation_warning(
+        "`padding_value` parameter will be removed in future versions. "
+        "Instead, you should specify `padding_value` for each column in TensorSchema"
+    )
     def __init__(
         self,
         sequential: SequentialDataset,
```
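The training-dataset change above makes the one-step sequence shift conditional: only features marked `is_seq` in the schema are truncated to form model inputs, while static features now pass through unchanged. A toy illustration of the new loop's semantics (plain tensors and hypothetical feature names, no RePlay types):

```python
import torch

sequence_shift = 1
features = {
    "item_id": torch.tensor([10, 11, 12, 13]),  # sequential feature: shifted
    "user_age": torch.tensor([42]),             # static feature: kept as-is
}
is_seq = {"item_id": True, "user_age": False}   # stand-in for schema[...].is_seq

output_features = {}
for name, feature in features.items():
    if is_seq[name]:
        feature = feature[:-sequence_shift]     # drop the target step
    output_features[name] = feature

assert output_features["item_id"].tolist() == [10, 11, 12]
assert output_features["user_age"].tolist() == [42]
```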
replay/models/nn/sequential/sasrec/lightning.py

```diff
@@ -33,7 +33,7 @@ class SasRec(lightning.LightningModule):
         loss_sample_count: Optional[int] = None,
         negative_sampling_strategy: str = "global_uniform",
         negatives_sharing: bool = False,
-        optimizer_factory: Optional[OptimizerFactory] = None,
+        optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
         lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
     ):
         """
@@ -63,7 +63,7 @@ class SasRec(lightning.LightningModule):
         :param negatives_sharing: Apply negative sharing in calculating sampled logits.
             Default: ``False``.
         :param optimizer_factory: Optimizer factory.
-            Default: ``None``.
+            Default: ``FatOptimizerFactory``.
         :param lr_scheduler_factory: Learning rate schedule factory.
             Default: ``None``.
         """
@@ -92,6 +92,7 @@ class SasRec(lightning.LightningModule):
         item_count = tensor_schema.item_id_features.item().cardinality
         assert item_count
         self._vocab_size = item_count
+        self.candidates_to_score = None
 
     def training_step(self, batch: SasRecTrainingBatch, batch_idx: int) -> torch.Tensor:
         """
@@ -106,30 +107,58 @@ class SasRec(lightning.LightningModule):
         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
         return loss
 
-    def
+    def predict_step(
+        self,
+        batch: SasRecPredictionBatch,
+        batch_idx: int,  # noqa: ARG002
+        dataloader_idx: int = 0,  # noqa: ARG002
+    ) -> torch.Tensor:
         """
-        :param
-        :param
+        :param batch: Batch of prediction data.
+        :param batch_idx: Batch index.
+        :param dataloader_idx: Dataloader index.
 
         :returns: Calculated scores.
         """
-
+        batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
+        return self._model_predict(batch.features, batch.padding_mask)
 
-    def
-        self,
+    def predict(
+        self,
+        batch: SasRecPredictionBatch,
+        candidates_to_score: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         """
         :param batch: Batch of prediction data.
-        :param
-
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
 
         :returns: Calculated scores.
         """
-        batch = self.
-        return self._model_predict(batch.features, batch.padding_mask)
+        batch = _prepare_prediction_batch(self._schema, self._model.max_len, batch)
+        return self._model_predict(batch.features, batch.padding_mask, candidates_to_score)
+
+    def forward(
+        self,
+        feature_tensors: TensorMap,
+        padding_mask: torch.BoolTensor,
+        candidates_to_score: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:  # pragma: no cover
+        """
+        :param feature_tensors: Batch of features.
+        :param padding_mask: Padding mask where 0 - <PAD>, 1 otherwise.
+        :param candidates_to_score: Item ids to calculate scores.
+            Default: ``None``.
+
+        :returns: Calculated scores.
+        """
+        return self._model_predict(feature_tensors, padding_mask, candidates_to_score)
 
     def validation_step(
-        self,
+        self,
+        batch: SasRecValidationBatch,
+        batch_idx: int,  # noqa: ARG002
+        dataloader_idx: int = 0,  # noqa: ARG002
     ) -> torch.Tensor:
         """
         :param batch (SasRecValidationBatch): Batch of prediction data.
@@ -143,8 +172,7 @@ class SasRec(lightning.LightningModule):
         """
         :returns: Configured optimizer and lr scheduler.
         """
-
-        optimizer = optimizer_factory.create(self._model.parameters())
+        optimizer = self._optimizer_factory.create(self._model.parameters())
 
         if self._lr_scheduler_factory is None:
             return optimizer
@@ -152,38 +180,16 @@ class SasRec(lightning.LightningModule):
         lr_scheduler = self._lr_scheduler_factory.create(optimizer)
         return [optimizer], [lr_scheduler]
 
-    def _prepare_prediction_batch(self, batch: SasRecPredictionBatch) -> SasRecPredictionBatch:
-
-
-
-
-
-            raise ValueError(msg)
-
-        if batch.padding_mask.shape[1] < self._model.max_len:
-            query_id, padding_mask, features = batch
-            sequence_item_count = padding_mask.shape[1]
-            for feature_name, feature_tensor in features.items():
-                if self._schema[feature_name].is_cat:
-                    features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor, (self._model.max_len - sequence_item_count, 0), value=0
-                    )
-                else:
-                    features[feature_name] = torch.nn.functional.pad(
-                        feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
-                        (self._model.max_len - sequence_item_count, 0),
-                        value=0,
-                    ).unsqueeze(-1)
-            padding_mask = torch.nn.functional.pad(
-                padding_mask, (self._model.max_len - sequence_item_count, 0), value=0
-            )
-            batch = SasRecPredictionBatch(query_id, padding_mask, features)
-        return batch
-
-    def _model_predict(self, feature_tensors: TensorMap, padding_mask: torch.BoolTensor) -> torch.Tensor:
+    def _model_predict(
+        self,
+        feature_tensors: TensorMap,
+        padding_mask: torch.BoolTensor,
+        candidates_to_score: torch.LongTensor = None,
+    ) -> torch.Tensor:
         model: SasRecModel
         model = cast(SasRecModel, self._model.module) if isinstance(self._model, torch.nn.DataParallel) else self._model
-        scores = model.predict(feature_tensors, padding_mask)
+        candidates_to_score = self.candidates_to_score if candidates_to_score is None else candidates_to_score
+        scores = model.predict(feature_tensors, padding_mask, candidates_to_score)
         return scores
 
     def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
@@ -479,6 +485,50 @@ class SasRec(lightning.LightningModule):
 
         self._set_new_item_embedder_to_model(new_embedding, new_vocab_size)
 
+    @property
+    def optimizer_factory(self) -> OptimizerFactory:
+        """
+        Returns current optimizer_factory.
+        """
+        return self._optimizer_factory
+
+    @optimizer_factory.setter
+    def optimizer_factory(self, optimizer_factory: OptimizerFactory) -> None:
+        """
+        Sets new optimizer_factory.
+        :param optimizer_factory: New optimizer factory.
+        """
+        if isinstance(optimizer_factory, OptimizerFactory):
+            self._optimizer_factory = optimizer_factory
+        else:
+            msg = f"Expected optimizer_factory of type OptimizerFactory, got {type(optimizer_factory)}"
+            raise ValueError(msg)
+
+    @property
+    def candidates_to_score(self) -> Union[torch.LongTensor, None]:
+        """
+        Returns tensor of item ids to calculate scores.
+        """
+        return self._candidates_to_score
+
+    @candidates_to_score.setter
+    def candidates_to_score(self, candidates: Optional[torch.LongTensor] = None) -> None:
+        """
+        Sets tensor of item ids to calculate scores.
+        :param candidates: Tensor of item ids to calculate scores.
+        """
+        total_item_count = self._model.item_count
+        if isinstance(candidates, torch.Tensor) and candidates.dtype is torch.long:
+            if 0 < candidates.shape[0] <= total_item_count:
+                self._candidates_to_score = candidates
+            else:
+                msg = f"Expected candidates length to be between 1 and {total_item_count=}"
+                raise ValueError(msg)
+        elif candidates is not None:
+            msg = f"Expected candidates to be of type torch.LongTensor or None, got {type(candidates)}"
+            raise ValueError(msg)
+        self._candidates_to_score = candidates
+
 
     def _set_new_item_embedder_to_model(self, new_embedding: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.item_emb = new_embedding
@@ -486,11 +536,37 @@ class SasRec(lightning.LightningModule):
         self._model.item_count = new_vocab_size
         self._model.padding_idx = new_vocab_size
         self._model.masking.padding_idx = new_vocab_size
-        self._model.candidates_to_score = torch.tensor(
-            list(range(new_embedding.weight.data.shape[0] - 1)),
-            device=self._model.candidates_to_score.device,
-            dtype=torch.long,
-        )
         self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(
             new_embedding.weight.data.shape[0] - 1
         )
+
+
+def _prepare_prediction_batch(
+    schema: TensorSchema, max_len: int, batch: SasRecPredictionBatch
+) -> SasRecPredictionBatch:
+    if batch.padding_mask.shape[1] > max_len:
+        msg = (
+            "The length of the submitted sequence "
+            "must not exceed the maximum length of the sequence. "
+            f"The length of the sequence is given {batch.padding_mask.shape[1]}, "
+            f"while the maximum length is {max_len}"
+        )
+        raise ValueError(msg)
+
+    if batch.padding_mask.shape[1] < max_len:
+        query_id, padding_mask, features = batch
+        sequence_item_count = padding_mask.shape[1]
+        for feature_name, feature_tensor in features.items():
+            if schema[feature_name].is_cat:
+                features[feature_name] = torch.nn.functional.pad(
+                    feature_tensor, (max_len - sequence_item_count, 0), value=0
+                )
+            else:
+                features[feature_name] = torch.nn.functional.pad(
+                    feature_tensor.view(feature_tensor.size(0), feature_tensor.size(1)),
+                    (max_len - sequence_item_count, 0),
+                    value=0,
+                ).unsqueeze(-1)
+        padding_mask = torch.nn.functional.pad(padding_mask, (max_len - sequence_item_count, 0), value=0)
+        batch = SasRecPredictionBatch(query_id, padding_mask, features)
+    return batch
```
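Taken together, the lightning changes expose candidate scoring in two ways: per call through `predict(batch, candidates_to_score=...)`, and globally through the new `candidates_to_score` property, which `_model_predict` falls back to when no explicit tensor is passed (this is the path `trainer.predict` hits via `predict_step`). A hedged usage sketch; the trained `model`, `trainer`, `batch`, and `prediction_dataloader` are placeholders:

```python
import torch

candidates = torch.tensor([7, 19, 256], dtype=torch.long)

# Per call: only the three candidate items are scored for this batch.
scores = model.predict(batch, candidates_to_score=candidates)

# Globally: every predict_step in trainer.predict uses the same candidates.
model.candidates_to_score = candidates
predictions = trainer.predict(model, dataloaders=prediction_dataloader)

# The setter validates its input: anything but a torch.LongTensor of valid
# length (or None, which restores full-catalog scoring) raises ValueError.
model.candidates_to_score = None
```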